{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "train_models_with_tensorFlow_decision_forests.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "36EdAGhThQov"
   },
   "source": [
    "# Building, Training and Evaluating Models with TensorFlow Decision Forests\n",
    "\n",
    "## Overview\n",
    "\n",
    "In this lab, you use the TensorFlow Decision Forests (TF-DF) library to train, evaluate, interpret, and run inference with Decision Forest models.\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "In this notebook, you learn how to:\n",
    "\n",
    "1. Train a binary classification Random Forest on a dataset containing numerical, categorical and missing features.\n",
    "2. Evaluate the model on a test dataset and prepare the model for [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving).\n",
    "3. Examine the overall structure of the model and the importance of each feature.\n",
    "4. Re-train the model with a different learning algorithm (Gradient Boosted Decision Trees) and use a different set of input features.\n",
    "5. Change the hyperparameters of the model.\n",
    "6. Preprocess the features and train a model for regression.\n",
    "7. Train a model for ranking.\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This tutorial shows how to use the TensorFlow Decision Forests (TF-DF) library to train, evaluate, interpret, and run inference with Decision Forest models.\n",
    "\n",
    "Decision Forests (DF) are a large family of Machine Learning algorithms for supervised classification, regression and ranking. As the name suggests, DFs use decision trees as a building block. Today, the two most popular DF training algorithms are [Random Forests](https://en.wikipedia.org/wiki/Random_forest) and [Gradient Boosted Decision Trees](https://en.wikipedia.org/wiki/Gradient_boosting). Both algorithms are ensemble techniques that use multiple decision trees, but they differ in how they build and combine them: a Random Forest trains each tree independently on a random sample of the training data and combines their predictions by averaging or voting, while Gradient Boosted Decision Trees train the trees one after another, with each new tree correcting the errors of the trees trained before it.\n",
    "\n",
    "Each learning objective will correspond to a __#TODO__ in this student lab notebook -- try to complete this notebook first and then review the [solution notebook](../solutions/train_models_with_tensorFlow_decision_forests.ipynb)."
   ]
  },
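  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between the two ensemble methods described above concrete, here is a minimal pure-Python sketch. It is an illustration only, not how TF-DF implements either algorithm: one-split stumps on a single numeric feature stand in for full decision trees. The Random Forest-style ensemble fits each stump independently on a bootstrap sample and averages the predictions. The boosting-style ensemble fits each stump to the residuals left by the previous stumps and adds up their shrunken predictions. `fit_stump`, `random_forest` and `boosted_trees` are hypothetical helpers, and the toy data is made up.\n",
    "\n",
    "```python\n",
    "import random\n",
    "import statistics\n",
    "\n",
    "\n",
    "def fit_stump(xs, ys):\n",
    "    # Try every threshold on x and keep the one with the lowest squared error;\n",
    "    # the stump predicts the mean target on each side of that threshold.\n",
    "    best = None\n",
    "    for t in sorted(set(xs))[:-1]:\n",
    "        left = [y for x, y in zip(xs, ys) if x <= t]\n",
    "        right = [y for x, y in zip(xs, ys) if x > t]\n",
    "        lm, rm = statistics.mean(left), statistics.mean(right)\n",
    "        err = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)\n",
    "        if best is None or err < best[0]:\n",
    "            best = (err, t, lm, rm)\n",
    "    if best is None:  # every x is identical: fall back to the global mean\n",
    "        m = statistics.mean(ys)\n",
    "        return lambda x: m\n",
    "    _, t, lm, rm = best\n",
    "    return lambda x: lm if x <= t else rm\n",
    "\n",
    "\n",
    "def random_forest(xs, ys, n_trees=20, seed=0):\n",
    "    rng = random.Random(seed)\n",
    "    trees = []\n",
    "    for _ in range(n_trees):\n",
    "        # Bootstrap: sample len(xs) rows with replacement; each tree is independent.\n",
    "        idx = [rng.randrange(len(xs)) for _ in xs]\n",
    "        trees.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))\n",
    "    # Combine the independent trees by averaging their outputs.\n",
    "    return lambda x: statistics.mean(tree(x) for tree in trees)\n",
    "\n",
    "\n",
    "def boosted_trees(xs, ys, n_trees=20, learning_rate=0.3):\n",
    "    base = statistics.mean(ys)\n",
    "    residuals = [y - base for y in ys]\n",
    "    trees = []\n",
    "    for _ in range(n_trees):\n",
    "        # Each new tree is fitted to what the ensemble still gets wrong.\n",
    "        tree = fit_stump(xs, residuals)\n",
    "        trees.append(tree)\n",
    "        residuals = [r - learning_rate * tree(x) for x, r in zip(xs, residuals)]\n",
    "    # Combine the sequential trees by summing their shrunken outputs.\n",
    "    return lambda x: base + learning_rate * sum(tree(x) for tree in trees)\n",
    "\n",
    "\n",
    "xs = [1, 2, 3, 4, 5, 6, 7, 8]\n",
    "ys = [1.0, 1.2, 0.9, 1.1, 3.0, 3.2, 2.9, 3.1]\n",
    "rf, gbt = random_forest(xs, ys), boosted_trees(xs, ys)\n",
    "print([round(rf(x), 2) for x in (2, 7)], [round(gbt(x), 2) for x in (2, 7)])\n",
    "```\n",
    "\n",
    "Both toy ensembles separate the low and high groups. The difference lies in how the trees are trained (independently vs. sequentially) and how they are combined (average vs. sum)."
   ]
  },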
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jK9tCTcwqq4k"
   },
   "source": [
    "## Installing TensorFlow Decision Forests\n",
    "\n",
    "Install TF-DF by running the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "Pa1Pf37RhEYN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Collecting tensorflow_decision_forests\n",
       "  Downloading tensorflow_decision_forests-0.2.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (17.7 MB)\n",
       "     |████████████████████████████████| 17.7 MB 6.4 MB/s            \n",
       "\u001b[?25hCollecting wurlitzer\n",
       "  Downloading wurlitzer-3.0.2-py3-none-any.whl (7.3 kB)\n",
       "Requirement already satisfied: absl-py in /opt/conda/lib/python3.7/site-packages (from tensorflow_decision_forests) (0.15.0)\n",
       "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from tensorflow_decision_forests) (1.16.0)\n",
       "Collecting tensorflow~=2.7.0\n",
       "  Downloading tensorflow-2.7.0-cp37-cp37m-manylinux2010_x86_64.whl (489.6 MB)\n",
       "     |████████████████████████████████| 489.6 MB 15 kB/s               \n",
       "\u001b[?25hRequirement already satisfied: wheel in /opt/conda/lib/python3.7/site-packages (from tensorflow_decision_forests) (0.37.0)\n",
       "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from tensorflow_decision_forests) (1.19.5)\n",
       "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from tensorflow_decision_forests) (1.3.5)\n",
       "Requirement already satisfied: astunparse>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.6.3)\n",
       "Collecting tensorflow-io-gcs-filesystem>=0.21.0\n",
       "  Downloading tensorflow_io_gcs_filesystem-0.23.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.1 MB)\n",
       "     |████████████████████████████████| 2.1 MB 55.7 MB/s            \n",
       "\u001b[?25hRequirement already satisfied: tensorboard~=2.6 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (2.6.0)\n",
       "Collecting libclang>=9.0.1\n",
       "  Downloading libclang-12.0.0-py2.py3-none-manylinux1_x86_64.whl (13.4 MB)\n",
       "     |████████████████████████████████| 13.4 MB 54.2 MB/s            \n",
       "\u001b[?25hRequirement already satisfied: termcolor>=1.1.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.1.0)\n",
       "Requirement already satisfied: protobuf>=3.9.2 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (3.19.1)\n",
       "Requirement already satisfied: keras-preprocessing>=1.1.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.1.2)\n",
       "Requirement already satisfied: google-pasta>=0.1.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (0.2.0)\n",
       "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.43.0)\n",
       "Requirement already satisfied: opt-einsum>=2.3.2 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (3.3.0)\n",
       "Collecting keras<2.8,>=2.7.0rc0\n",
       "  Downloading keras-2.7.0-py2.py3-none-any.whl (1.3 MB)\n",
       "     |████████████████████████████████| 1.3 MB 67.4 MB/s            \n",
       "\u001b[?25hRequirement already satisfied: typing-extensions>=3.6.6 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (4.0.1)\n",
       "Requirement already satisfied: h5py>=2.9.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (3.1.0)\n",
       "Collecting tensorflow-estimator<2.8,~=2.7.0rc0\n",
       "  Downloading tensorflow_estimator-2.7.0-py2.py3-none-any.whl (463 kB)\n",
       "     |████████████████████████████████| 463 kB 64.7 MB/s            \n",
       "\u001b[?25hRequirement already satisfied: gast<0.5.0,>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (0.4.0)\n",
       "Requirement already satisfied: wrapt>=1.11.0 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.13.3)\n",
       "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /opt/conda/lib/python3.7/site-packages (from tensorflow~=2.7.0->tensorflow_decision_forests) (1.12)\n",
       "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->tensorflow_decision_forests) (2.8.2)\n",
       "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->tensorflow_decision_forests) (2021.3)\n",
       "Requirement already satisfied: cached-property in /opt/conda/lib/python3.7/site-packages (from h5py>=2.9.0->tensorflow~=2.7.0->tensorflow_decision_forests) (1.5.2)\n",
       "Collecting google-auth<2,>=1.6.3\n",
       "  Downloading google_auth-1.35.0-py2.py3-none-any.whl (152 kB)\n",
       "     |████████████████████████████████| 152 kB 67.7 MB/s            \n",
       "\u001b[?25hRequirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (0.4.6)\n",
       "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (59.6.0)\n",
       "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (1.8.0)\n",
       "Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (2.26.0)\n",
       "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (2.0.2)\n",
       "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (0.6.1)\n",
       "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.7/site-packages (from tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (3.3.6)\n",
       "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (4.8)\n",
       "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (0.2.7)\n",
       "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (4.2.4)\n",
       "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (1.3.0)\n",
       "Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (4.9.0)\n",
       "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (2021.10.8)\n",
       "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (3.1)\n",
       "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (1.26.7)\n",
       "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (2.0.9)\n",
       "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (3.6.0)\n",
       "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (0.4.8)\n",
       "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow~=2.7.0->tensorflow_decision_forests) (3.1.1)\n",
       "Installing collected packages: google-auth, tensorflow-io-gcs-filesystem, tensorflow-estimator, libclang, keras, wurlitzer, tensorflow, tensorflow-decision-forests\n",
       "  Attempting uninstall: google-auth\n",
       "    Found existing installation: google-auth 2.3.3\n",
       "    Uninstalling google-auth-2.3.3:\n",
       "      Successfully uninstalled google-auth-2.3.3\n",
       "  Attempting uninstall: tensorflow-estimator\n",
       "    Found existing installation: tensorflow-estimator 2.6.0\n",
       "    Uninstalling tensorflow-estimator-2.6.0:\n",
       "      Successfully uninstalled tensorflow-estimator-2.6.0\n",
       "  Attempting uninstall: keras\n",
       "    Found existing installation: keras 2.6.0\n",
       "    Uninstalling keras-2.6.0:\n",
       "      Successfully uninstalled keras-2.6.0\n",
       "  Attempting uninstall: tensorflow\n",
       "    Found existing installation: tensorflow 2.6.2\n",
       "    Uninstalling tensorflow-2.6.2:\n",
       "      Successfully uninstalled tensorflow-2.6.2\n",
       "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
       "explainable-ai-sdk 1.3.2 requires xai-image-widget, which is not installed.\n",
       "tfx-bsl 1.5.0 requires absl-py<0.13,>=0.9, but you have absl-py 0.15.0 which is incompatible.\n",
       "tfx-bsl 1.5.0 requires google-api-python-client<2,>=1.7.11, but you have google-api-python-client 2.33.0 which is incompatible.\n",
       "tfx-bsl 1.5.0 requires pyarrow<6,>=1, but you have pyarrow 6.0.1 which is incompatible.\n",
       "tensorflow-transform 1.5.0 requires absl-py<0.13,>=0.9, but you have absl-py 0.15.0 which is incompatible.\n",
       "tensorflow-transform 1.5.0 requires pyarrow<6,>=1, but you have pyarrow 6.0.1 which is incompatible.\n",
       "tensorflow-io 0.21.0 requires tensorflow<2.7.0,>=2.6.0, but you have tensorflow 2.7.0 which is incompatible.\n",
       "tensorflow-io 0.21.0 requires tensorflow-io-gcs-filesystem==0.21.0, but you have tensorflow-io-gcs-filesystem 0.23.1 which is incompatible.\n",
       "cloud-tpu-client 0.10 requires google-api-python-client==1.8.0, but you have google-api-python-client 2.33.0 which is incompatible.\u001b[0m\n",
       "Successfully installed google-auth-1.35.0 keras-2.7.0 libclang-12.0.0 tensorflow-2.7.0 tensorflow-decision-forests-0.2.2 tensorflow-estimator-2.7.0 tensorflow-io-gcs-filesystem-0.23.1 wurlitzer-3.0.2\n"
     ]
    }
   ],
   "source": [
    "# Install the specified package\n",
    "!pip install tensorflow_decision_forests"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jK9tCTcwqq4k"
   },
   "source": [
    "**Note:** Please ignore the dependency-conflict (incompatibility) errors reported by pip above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vZGda2dOe-hH"
   },
   "source": [
    "Install [Wurlitzer](https://pypi.org/project/wurlitzer/) to display\n",
    "the detailed training logs. This is only needed in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "lk26uBSCe8Du"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: wurlitzer in /opt/conda/lib/python3.7/site-packages (3.0.2)\n"
     ]
    }
   ],
   "source": [
    "# Install the specified package\n",
    "!pip install wurlitzer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3oinwbhXlggd"
   },
   "source": [
    "**Note:** Restart the kernel by clicking **Kernel > Restart Kernel** so that the newly installed packages are loaded."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3oinwbhXlggd"
   },
   "source": [
    "## Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "52W45tmDjD64"
   },
   "outputs": [],
   "source": [
    "# Import necessary libraries\n",
    "import tensorflow_decision_forests as tfdf\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import math\n",
    "\n",
    "try:\n",
    "  from wurlitzer import sys_pipes\n",
    "except ImportError:\n",
    "  from colabtools.googlelog import CaptureLog as sys_pipes\n",
    "\n",
    "from IPython.core.magic import register_line_magic\n",
    "from IPython.display import Javascript, display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0LPPwWxYxtDM"
   },
   "source": [
    "The following hidden code cell limits the output height. It relies on Colab's `google.colab.output` JavaScript API, so it only takes effect in Colab.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cellView": "form",
    "id": "2AhqJz3VmQM-"
   },
   "outputs": [],
   "source": [
    "# Some of the model training logs can cover the full\n",
    "# screen if not compressed to a smaller viewport.\n",
    "# This magic allows setting a max height for a cell.\n",
    "@register_line_magic\n",
    "def set_cell_height(size):\n",
    "  display(\n",
    "      Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n",
    "                 str(size) + \"})\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "8gVQ-txtjFU4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found TensorFlow Decision Forests v0.2.2\n"
     ]
    }
   ],
   "source": [
    "# Check the version of TensorFlow Decision Forests\n",
    "print(\"Found TensorFlow Decision Forests v\" + tfdf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QGRtRECujKeu"
   },
   "source": [
    "## Training a Random Forest model\n",
    "\n",
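    "In this section, we train, evaluate, analyse and export a binary classification Random Forest on the [Palmer's Penguins](https://allisonhorst.github.io/palmerpenguins/articles/intro.html) dataset. The toy version of the dataset used here contains only two species, `Adelie` and `Gentoo`, which makes the task binary.\n",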
    "\n",
    "<center>\n",
    "<img src=\"https://allisonhorst.github.io/palmerpenguins/man/figures/palmerpenguins.png\" width=\"150\"/></center>\n",
    "\n",
    "**Note:** The dataset was exported to a csv file without pre-processing: `library(palmerpenguins); write.csv(penguins, file=\"penguins_toy.csv\", quote=F, row.names=F)`. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3qsSU1RfmNiP"
   },
   "source": [
    "### Load the dataset and convert it into a tf.data.Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9nJ5igfElg2I"
   },
   "source": [
    "This dataset is very small (fewer than 200 examples) and stored as a CSV file. Therefore, use Pandas to load it.\n",
    "\n",
    "**Note:** Pandas is convenient because you don't have to type in the names of the input features to load them. For larger datasets (>1M examples), using the\n",
    "[TensorFlow Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) to read the files may be better suited.\n",
    "\n",
    "Download the CSV file and load it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "44Jq6g_mJFmj"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Copying gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/penguins_toy.csv...\n",
      "/ [1 files][  7.3 KiB/  7.3 KiB]                                                \n",
      "Operation completed over 1 objects/7.3 KiB.                                      \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>species</th>\n",
       "      <th>island</th>\n",
       "      <th>bill_length_mm</th>\n",
       "      <th>bill_depth_mm</th>\n",
       "      <th>flipper_length_mm</th>\n",
       "      <th>body_mass_g</th>\n",
       "      <th>sex</th>\n",
       "      <th>year</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Adelie</td>\n",
       "      <td>Torgersen</td>\n",
       "      <td>39.1</td>\n",
       "      <td>18.7</td>\n",
       "      <td>181.0</td>\n",
       "      <td>3750.0</td>\n",
       "      <td>male</td>\n",
       "      <td>2007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Adelie</td>\n",
       "      <td>Torgersen</td>\n",
       "      <td>39.5</td>\n",
       "      <td>17.4</td>\n",
       "      <td>186.0</td>\n",
       "      <td>3800.0</td>\n",
       "      <td>female</td>\n",
       "      <td>2007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Adelie</td>\n",
       "      <td>Torgersen</td>\n",
       "      <td>40.3</td>\n",
       "      <td>18.0</td>\n",
       "      <td>195.0</td>\n",
       "      <td>3250.0</td>\n",
       "      <td>female</td>\n",
       "      <td>2007</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \\\n",
       "0  Adelie  Torgersen            39.1           18.7              181.0   \n",
       "1  Adelie  Torgersen            39.5           17.4              186.0   \n",
       "2  Adelie  Torgersen            40.3           18.0              195.0   \n",
       "\n",
       "   body_mass_g     sex  year  \n",
       "0       3750.0    male  2007  \n",
       "1       3800.0  female  2007  \n",
       "2       3250.0  female  2007  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download the dataset\n",
    "!gcloud storage cp gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/penguins_toy.csv /tmp/penguins.csv\n",
    "\n",
    "# Load a dataset into a Pandas Dataframe.\n",
    "dataset_df = pd.read_csv(\"/tmp/penguins.csv\")\n",
    "\n",
    "# Display the first 3 examples.\n",
    "dataset_df.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "23AewWT1lkIK"
   },
   "source": [
    "The dataset contains a mix of numerical (e.g. `bill_depth_mm`) and categorical\n",
    "(e.g. `island`) features, some of which have missing values. Unlike neural-network-based models, TF-DF supports all these feature types natively, so there is no need for preprocessing such as one-hot encoding, normalization, or an extra `is_present` feature.\n",
    "\n",
    "Labels are a bit different: Keras metrics expect integer labels. The label (`species`) is stored as a string, so convert it into an integer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "uO_jz2sj0IBZ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label classes: ['Adelie', 'Gentoo']\n"
     ]
    }
   ],
   "source": [
    "# Encode the categorical label into an integer.\n",
    "#\n",
    "# Details:\n",
    "# This stage is necessary if your classification label is represented as a\n",
    "# string. Note: Keras expects classification labels to be integers.\n",
    "\n",
    "# Name of the label column.\n",
    "label = \"species\"\n",
    "\n",
    "classes = dataset_df[label].unique().tolist()\n",
    "print(f\"Label classes: {classes}\")\n",
    "\n",
    "dataset_df[label] = dataset_df[label].map(classes.index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vwJjLFhbtozI"
   },
   "source": [
    "Next, split the dataset into training and testing sets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "u7DEIxn2oB3U"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "114 examples in training, 55 examples for testing.\n"
     ]
    }
   ],
   "source": [
    "# Split the dataset into a training and a testing dataset.\n",
    "\n",
    "def split_dataset(dataset, test_ratio=0.30):\n",
    "  \"\"\"Randomly splits a pandas DataFrame into training and testing parts.\"\"\"\n",
    "  test_indices = np.random.rand(len(dataset)) < test_ratio\n",
    "  return dataset[~test_indices], dataset[test_indices]\n",
    "\n",
    "\n",
    "train_ds_pd, test_ds_pd = split_dataset(dataset_df)\n",
    "print(\"{} examples in training, {} examples for testing.\".format(\n",
    "    len(train_ds_pd), len(test_ds_pd)))"
   ]
  },
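  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`split_dataset` draws a new random mask every time it runs, so the sizes of the two parts, and the metrics computed later, change from one execution to the next. Below is a minimal sketch of a reproducible variant, assuming a fixed seed is acceptable for your experiment. `split_dataset_seeded` and the seed value `42` are illustrative choices, not part of TF-DF:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def split_dataset_seeded(dataset, test_ratio=0.30, seed=42):\n",
    "    # A dedicated Generator makes the split repeatable without touching\n",
    "    # NumPy's global random state.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    test_indices = rng.random(len(dataset)) < test_ratio\n",
    "    return dataset[~test_indices], dataset[test_indices]\n",
    "\n",
    "\n",
    "df = pd.DataFrame({'x': range(10), 'species': [0, 1] * 5})\n",
    "train_a, test_a = split_dataset_seeded(df)\n",
    "train_b, test_b = split_dataset_seeded(df)\n",
    "print(test_a.index.equals(test_b.index))  # same seed, same test rows\n",
    "```\n",
    "\n",
    "Passing the same `seed` selects the same test rows on every run. Changing it gives a different, but still repeatable, split."
   ]
  },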
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uWq7uQcCuBzO"
   },
   "source": [
    "Finally, convert the pandas DataFrames (`pd.DataFrame`) into TensorFlow datasets (`tf.data.Dataset`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "qtXgUBKluTX0"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
       "/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:2038: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only\n",
       "  features_dataframe = dataframe.drop(label, 1)\n",
       "2021-12-27 10:55:01.056823: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64\n",
       "2021-12-27 10:55:01.056870: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)\n",
       "2021-12-27 10:55:01.056895: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (tensorflow-2-6-20211227-161231): /proc/driver/nvidia/version does not exist\n",
       "2021-12-27 10:55:01.057333: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
       "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label)\n",
    "test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BRKLWIWNuOZ1"
   },
   "source": [
    "**Notes:** `pd_dataframe_to_tf_dataset` could have converted the label to integer for you.\n",
    "\n",
    "And, if you wanted to create the `tf.data.Dataset` yourself, there are a couple of things to remember:\n",
    "\n",
    "- The learning algorithms work with a one-epoch dataset and without shuffling.\n",
    "- The batch size does not impact the training algorithm, but a small value might slow down reading the dataset.\n"
   ]
  },
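  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below builds the `tf.data.Dataset` by hand while following both rules above: one epoch (no `.repeat()`) and no `.shuffle()`. It is an approximation written for illustration, not the implementation of `pd_dataframe_to_tf_dataset`. The helper name `make_tf_dataset` and the batch size of `64` are assumptions. It also assumes that string columns contain no missing values, because a `NaN` in a string column gives pandas a mixed-type column that TensorFlow is likely to reject.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def make_tf_dataset(df, label, batch_size=64):\n",
    "    features = df.drop(columns=[label])\n",
    "    labels = df[label].to_numpy()\n",
    "    # Map each column name to a 1-D array so the model receives one tensor per feature.\n",
    "    ds = tf.data.Dataset.from_tensor_slices(\n",
    "        ({name: features[name].to_numpy() for name in features.columns}, labels))\n",
    "    # No .shuffle() and no .repeat(): the learner reads exactly one unshuffled epoch.\n",
    "    return ds.batch(batch_size)\n",
    "\n",
    "\n",
    "df = pd.DataFrame({'bill_length_mm': [39.1, 46.5, 40.3],\n",
    "                   'island': ['Torgersen', 'Biscoe', 'Dream'],\n",
    "                   'species': [0, 1, 0]})\n",
    "ds = make_tf_dataset(df, label='species')\n",
    "for batch_features, batch_labels in ds:\n",
    "    print(sorted(batch_features), batch_labels.numpy())\n",
    "```\n",
    "\n",
    "Using `df.drop(columns=[label])` instead of the positional `drop(label, 1)` form also avoids the pandas `FutureWarning` shown in the output above."
   ]
  },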
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mYAoyfYtqHG4"
   },
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "xete-FbuqJCV"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmpsu6tyqbg as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:04.780772\n",
       "Training model\n",
       "Model trained in 0:00:00.017295\n",
       "Compiling model\n",
       "1/1 [==============================] - 5s 5s/step\n"
      ]
     }
   ],
   "source": [
    "%set_cell_height 300\n",
    "\n",
    "# Specify the model.\n",
    "model_1 = tfdf.keras.RandomForestModel()\n",
    "\n",
    "# Optionally, add evaluation metrics.\n",
    "model_1.compile(\n",
    "    metrics=[\"accuracy\"])\n",
    "\n",
    "# Train the model.\n",
    "# \"sys_pipes\" is optional. It enables the display of the training logs.\n",
    "# TODO\n",
    "with sys_pipes():\n",
    "  model_1.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OBnjxdip-MC0"
   },
   "source": [
    "### Remarks\n",
    "\n",
    "-   No input features are specified. Therefore, all the columns except the label\n",
    "    are used as input features. The features used by the model are shown in the\n",
    "    training logs and in the output of `model.summary()`.\n",
    "-   DFs natively consume numerical, categorical and categorical-set features, as\n",
    "    well as missing values. Numerical features do not need to be normalized, and\n",
    "    categorical string values do not need to be encoded in a dictionary.\n",
    "-   No training hyper-parameters are specified. Therefore the default\n",
    "    hyper-parameters will be used. Default hyper-parameters provide\n",
    "    reasonable results in most situations.\n",
    "-   Calling `compile` on the model before the `fit` is optional. Compile can be\n",
    "    used to provide extra evaluation metrics.\n",
    "-   Training algorithms do not need validation datasets. If a validation dataset\n",
    "    is provided, it will only be used to show metrics.\n",
    "\n",
    "**Note:** A *Categorical-Set* feature is composed of a set of categorical values (while a *Categorical* is only one value). More details and examples are given later."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tSdtNJUArBpl"
   },
   "source": [
    "## Evaluate the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Udtu_uS1paSu"
   },
   "source": [
    "Let's evaluate our model on the test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "xUy4ULEMtDXB"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "1/1 [==============================] - 0s 340ms/step - loss: 0.0000e+00 - accuracy: 1.0000\n",
       "\n",
       "loss: 0.0000\n",
       "accuracy: 1.0000\n"
     ]
    }
   ],
   "source": [
    "# TODO\n",
    "# Evaluate the model\n",
    "evaluation = model_1.evaluate(test_ds, return_dict=True)\n",
    "print()\n",
    "\n",
    "for name, value in evaluation.items():\n",
    "  print(f\"{name}: {value:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tlhfzZ34pfO4"
   },
   "source": [
    "**Remark:** The test accuracy is close to the Out-of-bag accuracy\n",
    "shown in the training logs.\n",
    "\n",
    "See the **Model Self Evaluation** section below for more evaluation methods."
   ]
  },
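  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `accuracy` reported by `evaluate` can also be recomputed by hand, which makes it clear what the metric measures. The sketch below assumes that, for this binary task, the model outputs one probability per example for the class encoded as `1` (`Gentoo` here); check the shape of `model_1.predict(test_ds)` before relying on that. It uses made-up probabilities rather than real model output, and `accuracy_from_probabilities` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def accuracy_from_probabilities(probabilities, labels, threshold=0.5):\n",
    "    # Flatten an (n, 1) prediction array and turn each probability into a 0/1 class id.\n",
    "    predicted = (np.asarray(probabilities).reshape(-1) > threshold).astype(int)\n",
    "    # Accuracy is the fraction of examples whose predicted class equals the label.\n",
    "    return float(np.mean(predicted == np.asarray(labels)))\n",
    "\n",
    "\n",
    "# Made-up values; with the trained model you would pass\n",
    "# model_1.predict(test_ds) and test_ds_pd[label] instead.\n",
    "probs = np.array([[0.97], [0.10], [0.62], [0.40]])\n",
    "labels = np.array([1, 0, 1, 1])\n",
    "print(accuracy_from_probabilities(probs, labels))\n",
    "```"
   ]
  },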
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mHBFtUeElRYz"
   },
   "source": [
    "## Prepare this model for TensorFlow Serving"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JbC4lmgfr5Sm"
   },
   "source": [
    "Export the model to the SavedModel format for later reuse, e.g. with\n",
    "[TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "08YWGr9U2fza"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-10-28 12:21:12.212876: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_saved_model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/my_saved_model/assets\n"
     ]
    }
   ],
   "source": [
    "# Save the model\n",
    "model_1.save(\"/tmp/my_saved_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6-8R02_SXpbq"
   },
   "source": [
    "## Plot the model\n",
    "\n",
    "Plotting a decision tree and following the first branches helps in learning about decision forests. In some cases, plotting a model can even be used for debugging.\n",
    "\n",
    "Because of differences in the way they are trained, some models are more interesting to plot than others. Because of the noise injected during training and the depth of the trees, plotting a Random Forest is less informative than plotting a CART or the first tree of a Gradient Boosted Tree.\n",
    "\n",
    "Nevertheless, let's plot the first tree of our Random Forest model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "KUIxf8N6Yjl0"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<script src=\"https://d3js.org/d3.v6.min.js\"></script>\n",
       "<div id=\"tree_plot_ff8b6abb7d5142889bd482156927e9db\"></div>\n",
       "<script>\n",
       "/*\n",
       " * Copyright 2021 Google LLC.\n",
       " * Licensed under the Apache License, Version 2.0 (the \"License\");\n",
       " * you may not use this file except in compliance with the License.\n",
       " * You may obtain a copy of the License at\n",
       " *\n",
       " *     https://www.apache.org/licenses/LICENSE-2.0\n",
       " *\n",
       " * Unless required by applicable law or agreed to in writing, software\n",
       " * distributed under the License is distributed on an \"AS IS\" BASIS,\n",
       " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
       " * See the License for the specific language governing permissions and\n",
       " * limitations under the License.\n",
       " */\n",
       "\n",
       "/**\n",
       " *  Plotting of decision trees generated by TF-DF.\n",
       " *\n",
       " *  A tree is a recursive structure of node objects.\n",
       " *  A node contains one or more of the following components:\n",
       " *\n",
       " *    - A value: Representing the output of the node. If the node is not a leaf,\n",
       " *      the value is only present for analysis i.e. it is not used for\n",
       " *      predictions.\n",
       " *\n",
       " *    - A condition : For non-leaf nodes, the condition (also known as split)\n",
       " *      defines a binary test to branch to the positive or negative child.\n",
       " *\n",
       " *    - An explanation: Generally a plot showing the relation between the label\n",
       " *      and the condition to give insights about the effect of the condition.\n",
       " *\n",
       " *    - Two children : For non-leaf nodes, the children nodes. The first\n",
       " *      children (i.e. \"node.children[0]\") is the negative children (drawn in\n",
       " *      red). The second children is the positive one (drawn in green).\n",
       " *\n",
       " */\n",
       "\n",
       "/**\n",
       " * Plots a single decision tree into a DOM element.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!tree} raw_tree Recursive tree structure.\n",
       " * @param {string} canvas_id Id of the output dom element.\n",
       " */\n",
       "function display_tree(options, raw_tree, canvas_id) {\n",
       "  console.log(options);\n",
       "\n",
       "  // Determine the node placement.\n",
       "  const tree_struct = d3.tree().nodeSize(\n",
       "      [options.node_y_offset, options.node_x_offset])(d3.hierarchy(raw_tree));\n",
       "\n",
       "  // Boundaries of the node placement.\n",
       "  let x_min = Infinity;\n",
       "  let x_max = -x_min;\n",
       "  let y_min = Infinity;\n",
       "  let y_max = -x_min;\n",
       "\n",
       "  tree_struct.each(d => {\n",
       "    if (d.x > x_max) x_max = d.x;\n",
       "    if (d.x < x_min) x_min = d.x;\n",
       "    if (d.y > y_max) y_max = d.y;\n",
       "    if (d.y < y_min) y_min = d.y;\n",
       "  });\n",
       "\n",
       "  // Size of the plot.\n",
       "  const width = y_max - y_min + options.node_x_size + options.margin * 2;\n",
       "  const height = x_max - x_min + options.node_y_size + options.margin * 2 +\n",
       "      options.node_y_offset - options.node_y_size;\n",
       "\n",
       "  const plot = d3.select(canvas_id);\n",
       "\n",
       "  // Tool tip\n",
       "  options.tooltip = plot.append('div')\n",
       "                        .attr('width', 100)\n",
       "                        .attr('height', 100)\n",
       "                        .style('padding', '4px')\n",
       "                        .style('background', '#fff')\n",
       "                        .style('box-shadow', '4px 4px 0px rgba(0,0,0,0.1)')\n",
       "                        .style('border', '1px solid black')\n",
       "                        .style('font-family', 'sans-serif')\n",
       "                        .style('font-size', options.font_size)\n",
       "                        .style('position', 'absolute')\n",
       "                        .style('z-index', '10')\n",
       "                        .attr('pointer-events', 'none')\n",
       "                        .style('display', 'none');\n",
       "\n",
       "  // Create canvas\n",
       "  const svg = plot.append('svg').attr('width', width).attr('height', height);\n",
       "  const graph =\n",
       "      svg.style('overflow', 'visible')\n",
       "          .append('g')\n",
       "          .attr('font-family', 'sans-serif')\n",
       "          .attr('font-size', options.font_size)\n",
       "          .attr(\n",
       "              'transform',\n",
       "              () => `translate(${options.margin},${\n",
       "                  - x_min + options.node_y_offset / 2 + options.margin})`);\n",
       "\n",
       "  // Plot bounding box.\n",
       "  if (options.show_plot_bounding_box) {\n",
       "    svg.append('rect')\n",
       "        .attr('width', width)\n",
       "        .attr('height', height)\n",
       "        .attr('fill', 'none')\n",
       "        .attr('stroke-width', 1.0)\n",
       "        .attr('stroke', 'black');\n",
       "  }\n",
       "\n",
       "  // Draw the edges.\n",
       "  display_edges(options, graph, tree_struct);\n",
       "\n",
       "  // Draw the nodes.\n",
       "  display_nodes(options, graph, tree_struct);\n",
       "}\n",
       "\n",
       "/**\n",
       " * Draw the nodes of the tree.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!graph} graph D3 search handle containing the graph.\n",
       " * @param {!tree_struct} tree_struct Structure of the tree (node placement,\n",
       " *     data, etc.).\n",
       " */\n",
       "function display_nodes(options, graph, tree_struct) {\n",
       "  const nodes = graph.append('g')\n",
       "                    .selectAll('g')\n",
       "                    .data(tree_struct.descendants())\n",
       "                    .join('g')\n",
       "                    .attr('transform', d => `translate(${d.y},${d.x})`);\n",
       "\n",
       "  nodes.append('rect')\n",
       "      .attr('x', 0.5)\n",
       "      .attr('y', 0.5)\n",
       "      .attr('width', options.node_x_size)\n",
       "      .attr('height', options.node_y_size)\n",
       "      .attr('stroke', 'lightgrey')\n",
       "      .attr('stroke-width', 1)\n",
       "      .attr('fill', 'white')\n",
       "      .attr('y', -options.node_y_size / 2);\n",
       "\n",
       "  // Brackets on the right of condition nodes without children.\n",
       "  non_leaf_node_without_children =\n",
       "      nodes.filter(node => node.data.condition != null && node.children == null)\n",
       "          .append('g')\n",
       "          .attr('transform', `translate(${options.node_x_size},0)`);\n",
       "\n",
       "  non_leaf_node_without_children.append('path')\n",
       "      .attr('d', 'M0,0 C 10,0 0,10 10,10')\n",
       "      .attr('fill', 'none')\n",
       "      .attr('stroke-width', 1.0)\n",
       "      .attr('stroke', '#F00');\n",
       "\n",
       "  non_leaf_node_without_children.append('path')\n",
       "      .attr('d', 'M0,0 C 10,0 0,-10 10,-10')\n",
       "      .attr('fill', 'none')\n",
       "      .attr('stroke-width', 1.0)\n",
       "      .attr('stroke', '#0F0');\n",
       "\n",
       "  const node_content = nodes.append('g').attr(\n",
       "      'transform',\n",
       "      `translate(0,${options.node_padding - options.node_y_size / 2})`);\n",
       "\n",
       "  node_content.append(node => create_node_element(options, node));\n",
       "}\n",
       "\n",
       "/**\n",
       " * Creates the D3 content for a single node.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!node} node Node to draw.\n",
       " * @return {!d3} D3 content.\n",
       " */\n",
       "function create_node_element(options, node) {\n",
       "  // Output accumulator.\n",
       "  let output = {\n",
       "    // Content to draw.\n",
       "    content: d3.create('svg:g'),\n",
       "    // Vertical offset to the next element to draw.\n",
       "    vertical_offset: 0\n",
       "  };\n",
       "\n",
       "  // Conditions.\n",
       "  if (node.data.condition != null) {\n",
       "    display_condition(options, node.data.condition, output);\n",
       "  }\n",
       "\n",
       "  // Values.\n",
       "  if (node.data.value != null) {\n",
       "    display_value(options, node.data.value, output);\n",
       "  }\n",
       "\n",
       "  // Explanations.\n",
       "  if (node.data.explanation != null) {\n",
       "    display_explanation(options, node.data.explanation, output);\n",
       "  }\n",
       "\n",
       "  return output.content.node();\n",
       "}\n",
       "\n",
       "\n",
       "/**\n",
       " * Adds a single line of text inside of a node.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {string} text Text to display.\n",
       " * @param {!output} output Output display accumulator.\n",
       " */\n",
       "function display_node_text(options, text, output) {\n",
       "  output.content.append('text')\n",
       "      .attr('x', options.node_padding)\n",
       "      .attr('y', output.vertical_offset)\n",
       "      .attr('alignment-baseline', 'hanging')\n",
       "      .text(text);\n",
       "  output.vertical_offset += 10;\n",
       "}\n",
       "\n",
       "/**\n",
       " * Adds a single line of text inside of a node with a tooltip.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {string} text Text to display.\n",
       " * @param {string} tooltip Text in the Tooltip.\n",
       " * @param {!output} output Output display accumulator.\n",
       " */\n",
       "function display_node_text_with_tooltip(options, text, tooltip, output) {\n",
       "  const item = output.content.append('text')\n",
       "                   .attr('x', options.node_padding)\n",
       "                   .attr('alignment-baseline', 'hanging')\n",
       "                   .text(text);\n",
       "\n",
       "  add_tooltip(options, item, () => tooltip);\n",
       "  output.vertical_offset += 10;\n",
       "}\n",
       "\n",
       "/**\n",
       " * Adds a tooltip to a dom element.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!dom} target Dom element to equip with a tooltip.\n",
       " * @param {!func} get_content Generates the html content of the tooltip.\n",
       " */\n",
       "function add_tooltip(options, target, get_content) {\n",
       "  function show(d) {\n",
       "    options.tooltip.style('display', 'block');\n",
       "    options.tooltip.html(get_content());\n",
       "  }\n",
       "\n",
       "  function hide(d) {\n",
       "    options.tooltip.style('display', 'none');\n",
       "  }\n",
       "\n",
       "  function move(d) {\n",
       "    options.tooltip.style('display', 'block');\n",
       "    options.tooltip.style('left', (d.pageX + 5) + 'px');\n",
       "    options.tooltip.style('top', d.pageY + 'px');\n",
       "  }\n",
       "\n",
       "  target.on('mouseover', show);\n",
       "  target.on('mouseout', hide);\n",
       "  target.on('mousemove', move);\n",
       "}\n",
       "\n",
       "/**\n",
       " * Adds a condition inside of a node.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!condition} condition Condition to display.\n",
       " * @param {!output} output Output display accumulator.\n",
       " */\n",
       "function display_condition(options, condition, output) {\n",
       "  threshold_format = d3.format('r');\n",
       "\n",
       "  if (condition.type === 'IS_MISSING') {\n",
       "    display_node_text(options, `${condition.attribute} is missing`, output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (condition.type === 'IS_TRUE') {\n",
       "    display_node_text(options, `${condition.attribute} is true`, output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (condition.type === 'NUMERICAL_IS_HIGHER_THAN') {\n",
       "    format = d3.format('r');\n",
       "    display_node_text(\n",
       "        options,\n",
       "        `${condition.attribute} >= ${threshold_format(condition.threshold)}`,\n",
       "        output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (condition.type === 'CATEGORICAL_IS_IN') {\n",
       "    display_node_text_with_tooltip(\n",
       "        options, `${condition.attribute} in [...]`,\n",
       "        `${condition.attribute} in [${condition.mask}]`, output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (condition.type === 'CATEGORICAL_SET_CONTAINS') {\n",
       "    display_node_text_with_tooltip(\n",
       "        options, `${condition.attribute} intersect [...]`,\n",
       "        `${condition.attribute} intersect [${condition.mask}]`, output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (condition.type === 'NUMERICAL_SPARSE_OBLIQUE') {\n",
       "    display_node_text_with_tooltip(\n",
       "        options, `Sparse oblique split...`,\n",
       "        `[${condition.attributes}]*[${condition.weights}]>=${\n",
       "            threshold_format(condition.threshold)}`,\n",
       "        output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  display_node_text(\n",
       "      options, `Non supported condition ${condition.type}`, output);\n",
       "}\n",
       "\n",
       "/**\n",
       " * Adds a value inside of a node.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!value} value Value to display.\n",
       " * @param {!output} output Output display accumulator.\n",
       " */\n",
       "function display_value(options, value, output) {\n",
       "  if (value.type === 'PROBABILITY') {\n",
       "    const left_margin = 0;\n",
       "    const right_margin = 50;\n",
       "    const plot_width = options.node_x_size - options.node_padding * 2 -\n",
       "        left_margin - right_margin;\n",
       "\n",
       "    let cusum = Array.from(d3.cumsum(value.distribution));\n",
       "    cusum.unshift(0);\n",
       "    const distribution_plot = output.content.append('g').attr(\n",
       "        'transform', `translate(0,${output.vertical_offset + 0.5})`);\n",
       "\n",
       "    distribution_plot.selectAll('rect')\n",
       "        .data(value.distribution)\n",
       "        .join('rect')\n",
       "        .attr('height', 10)\n",
       "        .attr(\n",
       "            'x',\n",
       "            (d, i) =>\n",
       "                (cusum[i] * plot_width + left_margin + options.node_padding))\n",
       "        .attr('width', (d, i) => d * plot_width)\n",
       "        .style('fill', (d, i) => d3.schemeSet1[i]);\n",
       "\n",
       "    const num_examples =\n",
       "        output.content.append('g')\n",
       "            .attr('transform', `translate(0,${output.vertical_offset})`)\n",
       "            .append('text')\n",
       "            .attr('x', options.node_x_size - options.node_padding)\n",
       "            .attr('alignment-baseline', 'hanging')\n",
       "            .attr('text-anchor', 'end')\n",
       "            .text(`(${value.num_examples})`);\n",
       "\n",
       "    const distribution_details = d3.create('ul');\n",
       "    distribution_details.selectAll('li')\n",
       "        .data(value.distribution)\n",
       "        .join('li')\n",
       "        .append('span')\n",
       "        .text(\n",
       "            (d, i) =>\n",
       "                'class ' + i + ': ' + d3.format('.3%')(value.distribution[i]));\n",
       "\n",
       "    add_tooltip(options, distribution_plot, () => distribution_details.html());\n",
       "    add_tooltip(options, num_examples, () => 'Number of examples');\n",
       "\n",
       "    output.vertical_offset += 10;\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  if (value.type === 'REGRESSION') {\n",
       "    display_node_text(\n",
       "        options,\n",
       "        'value: ' + d3.format('r')(value.value) + ` (` +\n",
       "            d3.format('.6')(value.num_examples) + `)`,\n",
       "        output);\n",
       "    return;\n",
       "  }\n",
       "\n",
       "  display_node_text(options, `Non supported value ${value.type}`, output);\n",
       "}\n",
       "\n",
       "/**\n",
       " * Adds an explanation inside of a node.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!explanation} explanation Explanation to display.\n",
       " * @param {!output} output Output display accumulator.\n",
       " */\n",
       "function display_explanation(options, explanation, output) {\n",
       "  // Margin before the explanation.\n",
       "  output.vertical_offset += 10;\n",
       "\n",
       "  display_node_text(\n",
       "      options, `Non supported explanation ${explanation.type}`, output);\n",
       "}\n",
       "\n",
       "\n",
       "/**\n",
       " * Draw the edges of the tree.\n",
       " * @param {!options} options Dictionary of configurations.\n",
       " * @param {!graph} graph D3 search handle containing the graph.\n",
       " * @param {!tree_struct} tree_struct Structure of the tree (node placement,\n",
       " *     data, etc.).\n",
       " */\n",
       "function display_edges(options, graph, tree_struct) {\n",
       "  // Draw an edge between a parent and a child node with a bezier.\n",
       "  function draw_single_edge(d) {\n",
       "    return 'M' + (d.source.y + options.node_x_size) + ',' + d.source.x + ' C' +\n",
       "        (d.source.y + options.node_x_size + options.edge_rounding) + ',' +\n",
       "        d.source.x + ' ' + (d.target.y - options.edge_rounding) + ',' +\n",
       "        d.target.x + ' ' + d.target.y + ',' + d.target.x;\n",
       "  }\n",
       "\n",
       "  graph.append('g')\n",
       "      .attr('fill', 'none')\n",
       "      .attr('stroke-width', 1.2)\n",
       "      .selectAll('path')\n",
       "      .data(tree_struct.links())\n",
       "      .join('path')\n",
       "      .attr('d', draw_single_edge)\n",
       "      .attr(\n",
       "          'stroke', d => (d.target === d.source.children[0]) ? '#0F0' : '#F00');\n",
       "}\n",
       "\n",
       "display_tree({\"margin\": 10, \"node_x_size\": 160, \"node_y_size\": 28, \"node_x_offset\": 180, \"node_y_offset\": 33, \"font_size\": 10, \"edge_rounding\": 20, \"node_padding\": 2, \"show_plot_bounding_box\": false}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.476, 0.316, 0.208], \"num_examples\": 250.0}, \"condition\": {\"type\": \"NUMERICAL_IS_HIGHER_THAN\", \"attribute\": \"bill_length_mm\", \"threshold\": 42.349998474121094}, \"children\": [{\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.06569343065693431, 0.5547445255474452, 0.3795620437956204], \"num_examples\": 137.0}, \"condition\": {\"type\": \"CATEGORICAL_IS_IN\", \"attribute\": \"island\", \"mask\": [\"Dream\"]}, \"children\": [{\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.018867924528301886, 0.0, 0.9811320754716981], \"num_examples\": 53.0}, \"condition\": {\"type\": \"NUMERICAL_IS_HIGHER_THAN\", \"attribute\": \"bill_length_mm\", \"threshold\": 45.650001525878906}, \"children\": [{\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.0, 0.0, 1.0], \"num_examples\": 44.0}}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.1111111111111111, 0.0, 0.8888888888888888], \"num_examples\": 9.0}}]}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.09523809523809523, 0.9047619047619048, 0.0], \"num_examples\": 84.0}, \"condition\": {\"type\": \"NUMERICAL_IS_HIGHER_THAN\", \"attribute\": \"bill_depth_mm\", \"threshold\": 17.549999237060547}, \"children\": [{\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [1.0, 0.0, 0.0], \"num_examples\": 8.0}}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.0, 1.0, 0.0], \"num_examples\": 76.0}}]}]}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.9734513274336283, 0.02654867256637168, 0.0], \"num_examples\": 113.0}, \"condition\": {\"type\": \"NUMERICAL_IS_HIGHER_THAN\", \"attribute\": \"bill_depth_mm\", \"threshold\": 16.049999237060547}, \"children\": 
[{\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [1.0, 0.0, 0.0], \"num_examples\": 107.0}}, {\"value\": {\"type\": \"PROBABILITY\", \"distribution\": [0.5, 0.5, 0.0], \"num_examples\": 6.0}}]}]}, \"#tree_plot_ff8b6abb7d5142889bd482156927e9db\")\n",
       "</script>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Plot the first tree of the model\n",
    "tfdf.model_plotter.plot_model_in_colab(model_1, tree_idx=0, max_depth=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cPcL_hDnY7Zy"
   },
   "source": [
    "The root node on the left contains the first condition (in this run, `bill_length_mm >= 42.35`; your values may differ), the number of examples (250) and the label distribution (the red-blue-green bar).\n",
    "\n",
    "Examples for which this condition evaluates to true are sent down the green path. The other ones are sent down the red path.\n",
    "\n",
    "The deeper the node, the more *pure* it becomes, i.e. the label distribution is biased toward a subset of classes.\n",
    "\n",
    "**Note:** Hover the mouse over the plot for details."
   ]
  },
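  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The routing rule described above can be sketched in plain Python. This is a minimal, hypothetical illustration rather than a TF-DF API: the helper names `condition_holds`, `route` and `gini` are made up for this example, and `tree` is a hand-copied, rounded excerpt of the first levels of the plotted tree, using the same nested layout (`condition`, `value`, `children`) as the JSON passed to `display_tree` in the plot output. Judging from the class distributions in that output, `children[0]` is assumed to be the branch taken when the condition is true (green) and `children[1]` the other one (red). The sketch also assumes every feature is present, so it does not model how the real model handles missing values. Purity is measured here with the Gini impurity, one common choice among several.\n",
    "\n",
    "```python\n",
    "# Hypothetical helpers (not part of TF-DF) that walk a tree stored as nested dicts.\n",
    "\n",
    "def condition_holds(condition, example):\n",
    "    # Evaluate one split condition on an example (a dict of feature values).\n",
    "    value = example[condition['attribute']]\n",
    "    if condition['type'] == 'NUMERICAL_IS_HIGHER_THAN':\n",
    "        return value >= condition['threshold']\n",
    "    if condition['type'] == 'CATEGORICAL_IS_IN':\n",
    "        return value in condition['mask']\n",
    "    raise ValueError('Unsupported condition: ' + condition['type'])\n",
    "\n",
    "\n",
    "def route(node, example):\n",
    "    # Assumption: children[0] is taken when the condition holds, children[1] otherwise.\n",
    "    while 'children' in node:\n",
    "        branch = 0 if condition_holds(node['condition'], example) else 1\n",
    "        node = node['children'][branch]\n",
    "    return node['value']\n",
    "\n",
    "\n",
    "def gini(distribution):\n",
    "    # Gini impurity: 0.0 for a pure node, larger when classes are mixed.\n",
    "    return 1.0 - sum(p * p for p in distribution)\n",
    "\n",
    "\n",
    "# Rounded excerpt of the first levels of the plotted tree (deeper nodes omitted).\n",
    "tree = {\n",
    "    'condition': {'type': 'NUMERICAL_IS_HIGHER_THAN',\n",
    "                  'attribute': 'bill_length_mm', 'threshold': 42.35},\n",
    "    'value': {'distribution': [0.476, 0.316, 0.208], 'num_examples': 250},\n",
    "    'children': [\n",
    "        {\n",
    "            'condition': {'type': 'CATEGORICAL_IS_IN',\n",
    "                          'attribute': 'island', 'mask': ['Dream']},\n",
    "            'value': {'distribution': [0.066, 0.555, 0.380], 'num_examples': 137},\n",
    "            'children': [\n",
    "                {'value': {'distribution': [0.019, 0.0, 0.981], 'num_examples': 53}},\n",
    "                {'value': {'distribution': [0.095, 0.905, 0.0], 'num_examples': 84}},\n",
    "            ],\n",
    "        },\n",
    "        {'value': {'distribution': [0.973, 0.027, 0.0], 'num_examples': 113}},\n",
    "    ],\n",
    "}\n",
    "\n",
    "# A long-billed penguin from Dream island takes the green path twice.\n",
    "leaf = route(tree, {'bill_length_mm': 48.0, 'island': 'Dream'})\n",
    "print(leaf['num_examples'], round(gini(leaf['distribution']), 3))\n",
    "print(round(gini(tree['value']['distribution']), 3))\n",
    "```\n",
    "\n",
    "With these rounded numbers, the Gini impurity drops from about 0.63 at the root to about 0.04 in the node reached by long-billed penguins from Dream island, which matches the observation that deeper nodes are purer."
   ]
  },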
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-ob3ovQ2seVY"
   },
   "source": [
    "## Model structure and feature importance\n",
    "\n",
    "The overall structure of the model is shown with `.summary()`. You will see:\n",
    "\n",
    "-   **Type**: The learning algorithm used to train the model (`Random Forest` in\n",
    "    our case).\n",
    "-   **Task**: The problem solved by the model (`Classification` in our case).\n",
    "-   **Input Features**: The input features of the model.\n",
    "-   **Variable Importance**: Different measures of the importance of each\n",
    "    feature for the model.\n",
    "-   **Out-of-bag evaluation**: The out-of-bag evaluation of the model. This is a\n",
    "    cheap and efficient alternative to cross-validation.\n",
    "-   **Number of {trees, nodes} and other metrics**: Statistics about the\n",
    "    structure of the decision forests.\n",
    "\n",
    "**Remark:** The summary's content depends on the learning algorithm (e.g.\n",
    "Out-of-bag is only available for Random Forest) and the hyper-parameters (e.g.\n",
    "the *mean-decrease-in-accuracy* variable importance can be disabled in the\n",
    "hyper-parameters)."
   ]
  },
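  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Out-of-bag evaluation relies on bagging: in a Random Forest, each tree is typically trained on a bootstrap sample, i.e. examples drawn with replacement from the training set, so each tree never sees some of the examples and can be evaluated on them without a separate validation set. The following pure-Python sketch only illustrates this sampling step; it is not TF-DF's implementation, and `bootstrap_indices`, `oob_fraction` and `oob_votes_per_example` are hypothetical helpers. With 250 training examples and 300 trees (the sizes reported by this notebook's tree plot and model summary), roughly 1/e, about 36.8%, of the examples are left out of each tree, so every example still gets an out-of-bag prediction from about a third of the trees.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def bootstrap_indices(n, rng):\n",
    "    # One bootstrap sample: n example indices drawn with replacement.\n",
    "    return [rng.randrange(n) for _ in range(n)]\n",
    "\n",
    "\n",
    "def oob_fraction(n, num_trees, seed=0):\n",
    "    # Average fraction of the n examples missing from each tree's bootstrap sample.\n",
    "    rng = random.Random(seed)\n",
    "    total = 0.0\n",
    "    for _ in range(num_trees):\n",
    "        in_bag = set(bootstrap_indices(n, rng))\n",
    "        total += (n - len(in_bag)) / n\n",
    "    return total / num_trees\n",
    "\n",
    "\n",
    "def oob_votes_per_example(n, num_trees, seed=0):\n",
    "    # For each example, count the trees that did not see it during training;\n",
    "    # only these trees contribute to that example's out-of-bag prediction.\n",
    "    rng = random.Random(seed)\n",
    "    votes = [0] * n\n",
    "    for _ in range(num_trees):\n",
    "        in_bag = set(bootstrap_indices(n, rng))\n",
    "        for i in range(n):\n",
    "            if i not in in_bag:\n",
    "                votes[i] += 1\n",
    "    return votes\n",
    "\n",
    "\n",
    "fraction = oob_fraction(n=250, num_trees=300)\n",
    "votes = oob_votes_per_example(n=250, num_trees=300)\n",
    "print(f'left out per tree: {fraction:.3f} (1/e = {1 / math.e:.3f})')\n",
    "print(f'out-of-bag trees per example: min {min(votes)}, mean {sum(votes) / len(votes):.1f}')\n",
    "```"
   ]
  },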
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "kzXME28Lq7Il"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Model: \"random_forest_model\"\n",
       "_________________________________________________________________\n",
       " Layer (type)                Output Shape              Param #   \n",
       "=================================================================\n",
       "=================================================================\n",
       "Total params: 1\n",
       "Trainable params: 0\n",
       "Non-trainable params: 1\n",
       "_________________________________________________________________\n",
       "Type: \"RANDOM_FOREST\"\n",
       "Task: CLASSIFICATION\n",
       "Label: \"__LABEL\"\n",
       "\n",
       "Input Features (7):\n",
       "\tbill_depth_mm\n",
       "\tbill_length_mm\n",
       "\tbody_mass_g\n",
       "\tflipper_length_mm\n",
       "\tisland\n",
       "\tsex\n",
       "\tyear\n",
       "\n",
       "No weights\n",
       "\n",
       "Variable Importance: MEAN_MIN_DEPTH:\n",
       "    1.            \"island\"  1.219722 ################\n",
       "    2.              \"year\"  1.219722 ################\n",
       "    3.           \"__LABEL\"  1.219722 ################\n",
       "    4.               \"sex\"  1.217500 ###############\n",
       "    5.    \"bill_length_mm\"  1.070278 ###########\n",
       "    6.       \"body_mass_g\"  0.954444 ########\n",
       "    7.     \"bill_depth_mm\"  0.766944 ##\n",
       "    8. \"flipper_length_mm\"  0.681944 \n",
       "\n",
       "Variable Importance: NUM_AS_ROOT:\n",
       "    1. \"flipper_length_mm\" 142.000000 ################\n",
       "    2.     \"bill_depth_mm\" 92.000000 #########\n",
       "    3.       \"body_mass_g\" 46.000000 ###\n",
       "    4.    \"bill_length_mm\" 20.000000 \n",
       "\n",
       "Variable Importance: NUM_NODES:\n",
       "    1. \"flipper_length_mm\" 173.000000 ################\n",
       "    2.     \"bill_depth_mm\" 130.000000 ############\n",
       "    3.       \"body_mass_g\" 62.000000 #####\n",
       "    4.    \"bill_length_mm\" 39.000000 ###\n",
       "    5.               \"sex\"  1.000000 \n",
       "\n",
       "Variable Importance: SUM_SCORE:\n",
       "    1. \"flipper_length_mm\" 5403.527420 ################\n",
       "    2.     \"bill_depth_mm\" 3187.556642 #########\n",
       "    3.       \"body_mass_g\" 1345.707064 ###\n",
       "    4.    \"bill_length_mm\" 547.815473 #\n",
       "    5.               \"sex\"  0.940019 \n",
       "\n",
       "\n",
       "\n",
       "Winner take all: true\n",
       "Out-of-bag evaluation: accuracy:1 logloss:0.0298891\n",
       "Number of trees: 300\n",
       "Total number of nodes: 1110\n",
       "\n",
       "Number of nodes by tree:\n",
       "Count: 300 Average: 3.7 StdDev: 1.12101\n",
       "Min: 3 Max: 7 Ignored: 0\n",
       "----------------------------------------------\n",
       "[ 3, 4) 208  69.33%  69.33% ##########\n",
       "[ 4, 5)   0   0.00%  69.33%\n",
       "[ 5, 6)  79  26.33%  95.67% ####\n",
       "[ 6, 7)   0   0.00%  95.67%\n",
       "[ 7, 7]  13   4.33% 100.00% #\n",
       "\n",
       "Depth by leafs:\n",
       "Count: 705 Average: 1.29929 StdDev: 0.464101\n",
       "Min: 1 Max: 3 Ignored: 0\n",
       "----------------------------------------------\n",
       "[ 1, 2) 496  70.35%  70.35% ##########\n",
       "[ 2, 3) 207  29.36%  99.72% ####\n",
       "[ 3, 3]   2   0.28% 100.00%\n",
       "\n",
       "Number of training obs by leaf:\n",
       "Count: 705 Average: 48.5106 StdDev: 45.7302\n",
       "Min: 5 Max: 109 Ignored: 0\n",
       "----------------------------------------------\n",
       "[   5,  10) 221  31.35%  31.35% ##########\n",
       "[  10,  15) 146  20.71%  52.06% #######\n",
       "[  15,  20)  37   5.25%  57.30% ##\n",
       "[  20,  26)   1   0.14%  57.45%\n",
       "[  26,  31)   0   0.00%  57.45%\n",
       "[  31,  36)   0   0.00%  57.45%\n",
       "[  36,  41)   0   0.00%  57.45%\n",
       "[  41,  47)   0   0.00%  57.45%\n",
       "[  47,  52)   0   0.00%  57.45%\n",
       "[  52,  57)   0   0.00%  57.45%\n",
       "[  57,  62)   0   0.00%  57.45%\n",
       "[  62,  68)   0   0.00%  57.45%\n",
       "[  68,  73)   0   0.00%  57.45%\n",
       "[  73,  78)   0   0.00%  57.45%\n",
       "[  78,  83)   0   0.00%  57.45%\n",
       "[  83,  89)   4   0.57%  58.01%\n",
       "[  89,  94)   7   0.99%  59.01%\n",
       "[  94,  99)  55   7.80%  66.81% ##\n",
       "[  99, 104) 125  17.73%  84.54% ######\n",
       "[ 104, 109] 109  15.46% 100.00% #####\n",
       "\n",
       "Attribute in nodes:\n",
       "\t173 : flipper_length_mm [NUMERICAL]\n",
       "\t130 : bill_depth_mm [NUMERICAL]\n",
       "\t62 : body_mass_g [NUMERICAL]\n",
       "\t39 : bill_length_mm [NUMERICAL]\n",
       "\t1 : sex [CATEGORICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 0:\n",
       "\t142 : flipper_length_mm [NUMERICAL]\n",
       "\t92 : bill_depth_mm [NUMERICAL]\n",
       "\t46 : body_mass_g [NUMERICAL]\n",
       "\t20 : bill_length_mm [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 1:\n",
       "\t173 : flipper_length_mm [NUMERICAL]\n",
       "\t130 : bill_depth_mm [NUMERICAL]\n",
       "\t62 : body_mass_g [NUMERICAL]\n",
       "\t38 : bill_length_mm [NUMERICAL]\n",
       "\t1 : sex [CATEGORICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 2:\n",
       "\t173 : flipper_length_mm [NUMERICAL]\n",
       "\t130 : bill_depth_mm [NUMERICAL]\n",
       "\t62 : body_mass_g [NUMERICAL]\n",
       "\t39 : bill_length_mm [NUMERICAL]\n",
       "\t1 : sex [CATEGORICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 3:\n",
       "\t173 : flipper_length_mm [NUMERICAL]\n",
       "\t130 : bill_depth_mm [NUMERICAL]\n",
       "\t62 : body_mass_g [NUMERICAL]\n",
       "\t39 : bill_length_mm [NUMERICAL]\n",
       "\t1 : sex [CATEGORICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 5:\n",
       "\t173 : flipper_length_mm [NUMERICAL]\n",
       "\t130 : bill_depth_mm [NUMERICAL]\n",
       "\t62 : body_mass_g [NUMERICAL]\n",
       "\t39 : bill_length_mm [NUMERICAL]\n",
       "\t1 : sex [CATEGORICAL]\n",
       "\n",
       "Condition type in nodes:\n",
       "\t404 : HigherCondition\n",
       "\t1 : ContainsBitmapCondition\n",
       "Condition type in nodes with depth <= 0:\n",
       "\t300 : HigherCondition\n",
       "Condition type in nodes with depth <= 1:\n",
       "\t403 : HigherCondition\n",
       "\t1 : ContainsBitmapCondition\n",
       "Condition type in nodes with depth <= 2:\n",
       "\t404 : HigherCondition\n",
       "\t1 : ContainsBitmapCondition\n",
       "Condition type in nodes with depth <= 3:\n",
       "\t404 : HigherCondition\n",
       "\t1 : ContainsBitmapCondition\n",
       "Condition type in nodes with depth <= 5:\n",
       "\t404 : HigherCondition\n",
       "\t1 : ContainsBitmapCondition\n",
       "Node format: NOT_SET\n",
       "\n",
       "Training OOB:\n",
       "\ttrees: 1, Out-of-bag evaluation: accuracy:0.97619 logloss:0.858182\n",
       "\ttrees: 11, Out-of-bag evaluation: accuracy:1 logloss:0.0179309\n",
       "\ttrees: 21, Out-of-bag evaluation: accuracy:1 logloss:0.0184216\n",
       "\ttrees: 31, Out-of-bag evaluation: accuracy:0.991228 logloss:0.0255191\n",
       "\ttrees: 41, Out-of-bag evaluation: accuracy:1 logloss:0.0289391\n",
       "\ttrees: 51, Out-of-bag evaluation: accuracy:1 logloss:0.0250122\n",
       "\ttrees: 61, Out-of-bag evaluation: accuracy:1 logloss:0.026478\n",
       "\ttrees: 71, Out-of-bag evaluation: accuracy:0.991228 logloss:0.0265787\n",
       "\ttrees: 81, Out-of-bag evaluation: accuracy:1 logloss:0.027323\n",
       "\ttrees: 91, Out-of-bag evaluation: accuracy:1 logloss:0.0276958\n",
       "\ttrees: 101, Out-of-bag evaluation: accuracy:0.991228 logloss:0.0302609\n",
       "\ttrees: 111, Out-of-bag evaluation: accuracy:0.991228 logloss:0.0302445\n",
       "\ttrees: 121, Out-of-bag evaluation: accuracy:0.991228 logloss:0.0305131\n",
       "\ttrees: 131, Out-of-bag evaluation: accuracy:1 logloss:0.031088\n",
       "\ttrees: 141, Out-of-bag evaluation: accuracy:1 logloss:0.0314104\n",
       "\ttrees: 151, Out-of-bag evaluation: accuracy:1 logloss:0.0310795\n",
       "\ttrees: 161, Out-of-bag evaluation: accuracy:1 logloss:0.0312522\n",
       "\ttrees: 171, Out-of-bag evaluation: accuracy:1 logloss:0.0318305\n",
       "\ttrees: 181, Out-of-bag evaluation: accuracy:1 logloss:0.0319385\n",
       "\ttrees: 191, Out-of-bag evaluation: accuracy:1 logloss:0.0310959\n",
       "\ttrees: 201, Out-of-bag evaluation: accuracy:1 logloss:0.0316298\n",
       "\ttrees: 211, Out-of-bag evaluation: accuracy:1 logloss:0.0311654\n",
       "\ttrees: 221, Out-of-bag evaluation: accuracy:1 logloss:0.0314555\n",
       "\ttrees: 231, Out-of-bag evaluation: accuracy:1 logloss:0.0315575\n",
       "\ttrees: 241, Out-of-bag evaluation: accuracy:1 logloss:0.0314899\n",
       "\ttrees: 251, Out-of-bag evaluation: accuracy:1 logloss:0.031193\n",
       "\ttrees: 261, Out-of-bag evaluation: accuracy:1 logloss:0.031569\n",
       "\ttrees: 271, Out-of-bag evaluation: accuracy:1 logloss:0.0309735\n",
       "\ttrees: 281, Out-of-bag evaluation: accuracy:1 logloss:0.0306458\n",
       "\ttrees: 291, Out-of-bag evaluation: accuracy:1 logloss:0.0299459\n",
       "\ttrees: 300, Out-of-bag evaluation: accuracy:1 logloss:0.0298891\n",
       "\n"
      ]
     }
   ],
   "source": [
    "# Print the overall structure of the model\n",
    "%set_cell_height 300\n",
    "model_1.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d4ApRpUm02zU"
   },
   "source": [
    "All the information in `summary` is also available programmatically through the model inspector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "G3xuB3jN1Cww"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"bill_depth_mm\" (1; #0),\n",
       " \"bill_length_mm\" (1; #1),\n",
       " \"body_mass_g\" (1; #2),\n",
       " \"flipper_length_mm\" (1; #3),\n",
       " \"island\" (4; #4),\n",
       " \"sex\" (4; #5),\n",
       " \"year\" (1; #6)]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The input features\n",
    "model_1.make_inspector().features()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "BZ2RBbU51L6s"
   },
   "outputs": [
     {
      "data": {
       "text/plain": [
        "{'MEAN_MIN_DEPTH': [(\"island\" (4; #4), 1.2197222222222215),\n",
        "  (\"year\" (1; #6), 1.2197222222222215),\n",
        "  (\"__LABEL\" (4; #7), 1.2197222222222215),\n",
        "  (\"sex\" (4; #5), 1.2174999999999994),\n",
        "  (\"bill_length_mm\" (1; #1), 1.0702777777777757),\n",
        "  (\"body_mass_g\" (1; #2), 0.9544444444444413),\n",
        "  (\"bill_depth_mm\" (1; #0), 0.7669444444444432),\n",
        "  (\"flipper_length_mm\" (1; #3), 0.6819444444444437)],\n",
        " 'NUM_NODES': [(\"flipper_length_mm\" (1; #3), 173.0),\n",
        "  (\"bill_depth_mm\" (1; #0), 130.0),\n",
        "  (\"body_mass_g\" (1; #2), 62.0),\n",
        "  (\"bill_length_mm\" (1; #1), 39.0),\n",
        "  (\"sex\" (4; #5), 1.0)],\n",
        " 'SUM_SCORE': [(\"flipper_length_mm\" (1; #3), 5403.527419958264),\n",
        "  (\"bill_depth_mm\" (1; #0), 3187.556642279029),\n",
        "  (\"body_mass_g\" (1; #2), 1345.7070640614256),\n",
        "  (\"bill_length_mm\" (1; #1), 547.8154730461538),\n",
        "  (\"sex\" (4; #5), 0.9400193095207214)],\n",
        " 'NUM_AS_ROOT': [(\"flipper_length_mm\" (1; #3), 142.0),\n",
        "  (\"bill_depth_mm\" (1; #0), 92.0),\n",
        "  (\"body_mass_g\" (1; #2), 46.0),\n",
        "  (\"bill_length_mm\" (1; #1), 20.0)]}"
       ]
      },
      "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "# The feature importances\n",
    "model_1.make_inspector().variable_importances()"
   ]
  },
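  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can post-process the returned importances in plain Python. The following is a minimal sketch, not part of the TF-DF API. `num_as_root` and `normalize` are hypothetical names, and the values are copied from the `NUM_AS_ROOT` output above, with feature names as plain strings. With the real inspector, each entry pairs a column spec (whose `name` field holds the feature name) with a score. Every tree has exactly one root node, so these counts add up to the number of trees: 142 + 92 + 46 + 20 = 300.\n",
    "\n",
    "```python\n",
    "# Hypothetical stand-in for\n",
    "# model_1.make_inspector().variable_importances()['NUM_AS_ROOT'],\n",
    "# using plain feature-name strings instead of column specs.\n",
    "num_as_root = [\n",
    "    ('flipper_length_mm', 142.0),\n",
    "    ('bill_depth_mm', 92.0),\n",
    "    ('body_mass_g', 46.0),\n",
    "    ('bill_length_mm', 20.0),\n",
    "]\n",
    "\n",
    "\n",
    "def normalize(importances):\n",
    "    # Divide each score by the total so that the shares sum to 1.\n",
    "    total = sum(score for _, score in importances)\n",
    "    return {name: score / total for name, score in importances}\n",
    "\n",
    "\n",
    "shares = normalize(num_as_root)\n",
    "for name, share in shares.items():\n",
    "    print(f'{name:<20} {share:6.1%}')\n",
    "```\n",
    "\n",
    "In these terms, `flipper_length_mm` is the root feature of roughly 47% of the trees (142 of 300)."
   ]
  },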
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0zvyRJVk1aEk"
   },
   "source": [
    "The content of the summary and of the inspector depends on the learning algorithm (`tfdf.keras.RandomForestModel` in this case) and its hyperparameters. For example, setting `compute_oob_variable_importances=True` makes the Random Forest learner also compute out-of-bag variable importances."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tFVmrHtWXYKY"
   },
   "source": [
    "## Model Self-Evaluation\n",
    "\n",
    "During training, TF-DF models can evaluate themselves even when no validation dataset is passed to the `fit()` method. The exact logic depends on the model. For example, Random Forest uses out-of-bag (OOB) evaluation, while Gradient Boosted Trees uses an internal train-validation split.\n",
    "\n",
    "**Note:** Although this evaluation is computed during training, it is NOT computed on the examples that the evaluated trees were trained on. You can therefore use it as a rough, low-quality evaluation of the model.\n",
    "\n",
    "The model self-evaluation is available through the inspector's `evaluation()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "BZPzyIMmYmsI"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Evaluation(num_examples=114, accuracy=1.0, loss=0.029889076245589216, rmse=None, ndcg=None, aucs=None)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TODO\n",
    "# Evaluate the model\n",
    "model_1.make_inspector().evaluation()"
   ]
  },
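  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The out-of-bag idea fits in a few lines of plain Python. This is a simplified illustration, not TF-DF's implementation. `fit_stump` and `oob_accuracy` are hypothetical helpers, and one-threshold decision stumps on a toy one-dimensional dataset stand in for full decision trees. Each tree is trained on a bootstrap sample (n draws with replacement), and each example is scored only by the trees that did not see it during training. A bootstrap sample leaves out about 1/e ≈ 37% of the examples, so a single-tree forest can only evaluate roughly a third of the data. This matches the training logs printed in the next section, which evaluate 42 examples after the first tree and 114 once more trees are added.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def fit_stump(xs, ys):\n",
    "    # Toy stand-in for a decision tree: choose the threshold t (among the\n",
    "    # observed values) with the fewest errors when predicting int(x > t).\n",
    "    best_t, best_err = None, None\n",
    "    for t in sorted(set(xs)):\n",
    "        err = sum(int(x > t) != y for x, y in zip(xs, ys))\n",
    "        if best_err is None or err < best_err:\n",
    "            best_t, best_err = t, err\n",
    "    return lambda x: int(x > best_t)\n",
    "\n",
    "\n",
    "def oob_accuracy(xs, ys, num_trees=50, seed=0):\n",
    "    rng = random.Random(seed)\n",
    "    n = len(xs)\n",
    "    votes = [Counter() for _ in range(n)]  # out-of-bag votes per example\n",
    "    for _ in range(num_trees):\n",
    "        bag = [rng.randrange(n) for _ in range(n)]  # bootstrap sample\n",
    "        in_bag = set(bag)\n",
    "        tree = fit_stump([xs[i] for i in bag], [ys[i] for i in bag])\n",
    "        for i in range(n):\n",
    "            if i not in in_bag:  # only trees that never saw example i vote\n",
    "                votes[i][tree(xs[i])] += 1\n",
    "    scored = [i for i in range(n) if votes[i]]\n",
    "    if not scored:\n",
    "        return float('nan'), 0\n",
    "    correct = sum(votes[i].most_common(1)[0][0] == ys[i] for i in scored)\n",
    "    return correct / len(scored), len(scored)\n",
    "\n",
    "\n",
    "xs = list(range(20))\n",
    "ys = [0] * 10 + [1] * 10\n",
    "for num_trees in (1, 10, 50):\n",
    "    acc, num_scored = oob_accuracy(xs, ys, num_trees=num_trees)\n",
    "    print(f'trees: {num_trees:2d}  OOB examples: {num_scored:2d}  OOB accuracy: {acc:.3f}')\n",
    "```\n",
    "\n",
    "As trees are added, more examples receive at least one out-of-bag vote, until the whole training set is covered."
   ]
  },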
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vBSz-jE0Qss_"
   },
   "source": [
    "## Plotting the training logs\n",
    "\n",
    "The training logs show the quality of the model (e.g. accuracy evaluated on the out-of-bag or validation dataset) as a function of the number of trees in the model. These logs help you study the trade-off between model size and model quality.\n",
    "\n",
    "The logs are available in multiple ways:\n",
    "\n",
    "1. Displayed during training if `fit()` is wrapped in `with sys_pipes():` (see the example above).\n",
    "1. At the end of the model summary, i.e. `model.summary()` (see the example above).\n",
    "1. Programmatically, using the model inspector, i.e. `model.make_inspector().training_logs()`.\n",
    "1. Using [TensorBoard](https://www.tensorflow.org/tensorboard).\n",
    "\n",
    "Let's try options 2 and 3:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "ZbRk7xvpTKQG"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 150})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "data": {
       "text/plain": [
        "[TrainLog(num_trees=1, evaluation=Evaluation(num_examples=42, accuracy=0.9761904761904762, loss=0.8581821804954892, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=11, evaluation=Evaluation(num_examples=113, accuracy=1.0, loss=0.017930947573839034, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=21, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.01842158549187476, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=31, evaluation=Evaluation(num_examples=114, accuracy=0.9912280701754386, loss=0.025519077061561115, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=41, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.028939139326674898, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=51, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.02501224357177291, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=61, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.026478012468208346, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=71, evaluation=Evaluation(num_examples=114, accuracy=0.9912280701754386, loss=0.026578680693841818, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=81, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.027322964514033835, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=91, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.027695828329837115, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=101, evaluation=Evaluation(num_examples=114, accuracy=0.9912280701754386, loss=0.030260881131286162, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=111, evaluation=Evaluation(num_examples=114, accuracy=0.9912280701754386, loss=0.030244526665723116, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=121, evaluation=Evaluation(num_examples=114, accuracy=0.9912280701754386, loss=0.030513084019746697, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=131, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03108797009968967, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=141, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.031410421081410164, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=151, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.031079513429288278, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=161, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03125219445740968, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=171, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03183049037072219, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=181, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.0319384854417621, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=191, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03109589420062931, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=201, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03162980020830506, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=211, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.031165373911917732, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=221, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.0314555442892015, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=231, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.0315574641061718, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=241, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03148993304031983, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=251, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03119300126931385, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=261, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.03156904316621653, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=271, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.030973476406775023, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=281, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.030645840090552445, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=291, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.02994587737249962, rmse=None, ndcg=None, aucs=None)),\n",
        " TrainLog(num_trees=300, evaluation=Evaluation(num_examples=114, accuracy=1.0, loss=0.029889076245589216, rmse=None, ndcg=None, aucs=None))]"
       ]
      },
      "execution_count": 18,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "%set_cell_height 150\n",
    "model_1.make_inspector().training_logs()"
   ]
  },
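  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also query the logs directly, for example to find the forest size with the lowest out-of-bag loss. In this minimal sketch, `Evaluation`, `TrainLog`, `logs` and `best_num_trees` are hypothetical stand-ins. The namedtuples mirror only the fields printed above, and `logs` holds a few rounded entries copied from that output. With the real model, you would pass `model_1.make_inspector().training_logs()` to `best_num_trees` instead.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "# Hypothetical stand-ins mirroring the fields printed in the logs above.\n",
    "Evaluation = namedtuple('Evaluation', ['num_examples', 'accuracy', 'loss'])\n",
    "TrainLog = namedtuple('TrainLog', ['num_trees', 'evaluation'])\n",
    "\n",
    "# A subset of the entries printed above (values rounded).\n",
    "logs = [\n",
    "    TrainLog(1, Evaluation(42, 0.976190, 0.858182)),\n",
    "    TrainLog(11, Evaluation(113, 1.0, 0.017931)),\n",
    "    TrainLog(21, Evaluation(114, 1.0, 0.018422)),\n",
    "    TrainLog(101, Evaluation(114, 0.991228, 0.030261)),\n",
    "    TrainLog(300, Evaluation(114, 1.0, 0.029889)),\n",
    "]\n",
    "\n",
    "\n",
    "def best_num_trees(logs):\n",
    "    # Return the log entry with the smallest out-of-bag loss.\n",
    "    return min(logs, key=lambda log: log.evaluation.loss)\n",
    "\n",
    "\n",
    "best = best_num_trees(logs)\n",
    "print(f'lowest OOB loss {best.evaluation.loss:.4f} at {best.num_trees} trees')\n",
    "```\n",
    "\n",
    "On these entries, as on the full log above, the out-of-bag log-loss is lowest after 11 trees. After that it drifts up slightly while accuracy stays close to 1. The forest is evaluated on only 114 out-of-bag examples, so such small differences should not be over-interpreted."
   ]
  },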
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WynFJCEbhuF_"
   },
   "source": [
    "Let's plot the training logs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "id": "xzPH7Gggh0g1"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAt0AAAEGCAYAAAC5JimDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABD8ElEQVR4nO3de5xkdX3n/9e7q7t6ppq59MAwAtMImkFEooQgMXEl3qJAVCKJG0i8hKBIIkbNZYPJJpq4vxU1aowxElAScbMQ10tEQ0TComSTKAw63AR0BHQaBmZgBmbo6unuqvr8/jjn9NQ0famuruo6VfN+Ph796KpzTlV9upo5fPpTn/P5KiIwMzMzM7P26et0AGZmZmZmvc5Jt5mZmZlZmznpNjMzMzNrMyfdZmZmZmZt5qTbzMzMzKzN+jsdwHI47LDD4phjjul0GGZmi3brrbc+GhHrOx3HcvI528y62Vzn7YMi6T7mmGPYvHlzp8MwM1s0ST/qdAzLzedsM+tmc5233V5iZmZmZtZmTrrNzMzMzNrMSbeZmZmZWZs56TYzMzMzazMn3WZmZmZmbda2pFvSFZJ2SLpzjv2S9FeStkq6XdLJdftOl3Rvuu/iuu3rJF0v6Qfp9+F2xW9mZmZm1irtrHT/PXD6PPvPADalXxcAnwSQVAA+ke4/AThX0gnpYy4GboiITcAN6X0zMzMzs1xr25zuiLhJ0jHzHHIWcGVEBPAtSWslHQEcA2yNiPsAJF2dHvu99PuL08d/BvgG8IftiL8R3/3xbm68Z8eCx20cLvFfnz+y5Nf76u0P8f2H9y75eQCOXLuSc049esnPc+0d27ln+54Fj3v+set40ab8rO9x54NP8PW7Hu50GHaQ+dVTj+aotSs7HUZP+vyto0xVa5zbgvOamVk7dHJxnKOAbXX3R9Nts23/mfT2hojYDhAR2yUdPteTS7qApILO0Ue35yT80X/9ATd9fyfS3MdEJN9ffsIG1g0Vm36tai343X+8jclqbd7Xa0QW00uffTiHr1rR9PPUasG7/nELE5X5Y4qAYw4t8Y0/eEnTr9VqH73++9xwz44lv5dmi3HaceuddLfJNbc9xJ7xKSfdZpZbnUy6Z0t3Yp7tixIRlwGXAZxyyimLfnwjntw3xYs2HcZnz/+ZOY+5/nuP8JYrN7NtV3lJSffDe/YxWa3x/rN/csn/U7nxnh2c9/e3sG3X+JKS7p1PTjBRqfG+XzqRN7zg6XMe94Gv3cOn/u0+qrWg0JePLPfHu8q88jkb+Ns3nNLpUMysBYaKBR5+YrzTYZiZzamT00tGgfqei43AQ/NsB3gkbUEh/b5wb0cblSerlIqFeY/ZOJxUtUZ3L+1/BqO7ygc831Lsj6m8tJh2NxbTxuGVTFWDR/bsW9LrtUpEMLp7nI3DpU6HYmYtUir2MzZR7XQYZmZz6mTSfQ3wxnSKyQuAJ9LWkVuATZKOlVQEzkmPzR7zpvT2m4AvL3fQ9cYmKwwV5/+wIEtIty0xwd2WJu0jLUgUs2RzqX8IbNvVWEzZ/m27lvYetMpjY5OMT1UZacEfMGaWD0ODBcqTlU6HYWY2p7a1l0i6iuSix8MkjQLvAQYAIuJS4FrgTGArUAbOS/dVJF0EXAcUgCsi4q70aS8BPifpfODHwOvaFX8jyhNVSoPzV7pXrRhgbWmgJVVlCY5Y23w7SGZlscBhhxSXtdKdHD/O3I04yyf7Y8OVbrPeUSr2MzbpSreZ5Vc7p5ecu8D+AN42x75rSZLymdsfA17WkgBboJFKNySV3qwq3Kxtu8Z52uoVDPbPn+Q3amOLYlq/apAVA/PHdFSLqv2tklXcR9Y56TbrFUPFApOVGlPVGgMFr/tmZvnjM1OTqrVg31SNlQv0dENS6W1FVbkV/dyZlsT0eGMxDfYX2LB6cMntLK2yv9Lt9hKzXlEaTAogZVe7zSynnHQ3KesdbKjSva7E6O5xIpofojK6
e7wl/dz1MT34+DjVWvMxbdvVeExJtT8nle7dySSZocFODu8xs1YaSgsg7us2s7xy0t2krJqyUE83JBXViUqNnU9ONPVaU9Ua258Yb3mle6oa7Njb3ESRai146PHGY0oq6/mpdLvKbdZbskq3J5iYWV456W7S2MQiKt3T0zuaSzq3P76PWsDGFvYgLzWmh/fso1KLhvuiR9aV2P7EOFPVWlOv10qju8ot/dTAzDrPlW4zyzsn3U2arnQ32NMNzc/Fzi5AbHV7CTQ/xm/6YsRFtJfUAh5+orOzumu1cKXbrAeViq50m1m+Oelu0nSlu4G+4KXOxW50NN9iHLl2BdJSYlrcxYjT88o73Ne988kJJqu1ln5qYHYwk3S6pHslbZV08Sz710j6iqTbJN0l6bx2xDE06Eq3meWbk+4mLabSnc3Fbr6qPE6hTxyxZukzujOD/QU2rFrR9Bi/bbuSueFHrm0s6c4q653u697WwpU9zQ52kgrAJ4AzgBOAcyWdMOOwtwHfi4jnkazd8OF04bOWmq50e3qJmeWUk+4mjU02XumGpNq9lEr3EWtW0N/i2bNLGRs4ujuZG17sbyymp61ZQZ86P6t7tIUre5oZpwJbI+K+iJgErgbOmnFMAKskCTgE2AW0vBw9XemecKXbzPLJSXeTyhONV7ohqfQ2XVVu8bjAzMi65hfI2bZ7cRcjDhT6OGJN5yeYuNJt1lJHAdvq7o+m2+r9NfBs4CHgDuAdEfGUK6olXSBps6TNO3fuXHQgrnSbWd456W7S2CLmdEOS5D3U5FzsVi+MUx/Tw3v2UWliosiDTVyMuHF4Zcd7ukd3N7aKppk1RLNsm3mSeyWwBTgSOAn4a0mrn/KgiMsi4pSIOGX9+vWLDiQrgLjSbWZ55aS7SYuZ0w1JO8NUNXhkz+Kmd+ybqvLInom2LFk+MlyiWgu2L3KiyPTc8EXGtJRqf6skFXpXuc1aZBQYqbu/kaSiXe884IuR2ArcDxzf6kAGCn0U+/tc6Taz3HLS3aSxiQr9faLYYJ/1/rGBi2uveOjx9i1ZPj1RZJGJ8PTc8CYq3Y/smWCi0rn/KSbjAt3PbdYitwCbJB2bXhx5DnDNjGN+DLwMQNIG4FnAfe0IZqhY8PQSM8stJ91NKk9WKRULJNcGLazZudjbsgv/2lHpziaKLLKvu9m54dnxD3aorztbRXNknSvdZq0QERXgIuA64G7gcxFxl6QLJV2YHvY+4Ock3QHcAPxhRDzajnhKxX7P6Taz3GqsIdmeYmyi0vDkEmh+LnY7ZnRnsokii51g0mxM9dX+Z6w/ZFGPbYVsFU1Xus1aJyKuBa6dse3SutsPAa9YjliGBl3pNrP8cqW7SVmlu1HNzsXetmucgYLYsKp1M7oz2USRbYv8Q6DZueHT1f4O9XUvdhVNM+supWK/e7rNLLecdDdpbHJxlW5obi726O4yR61dSV9fY20si9VsTM3MDd+wegUDBXVsbOBiV9E0s+4yNFjw9BIzyy0n3U0qTyyu0g3NzcXetnu8Lf3cmaZjaqJaXOgTR67t3NjAxa6iaWbdxZVuM8szJ91NGpusNDyjO7NxeCXbnxhnahFzsUd3tWdGd31Mj+zdt6iJItuWEFNSWe9cpXsxq2iaWXfx9BIzyzNnH00qT1YpLbK9ZGS4RC3g4QbnYpcnKzw2NtnWC/9GhktEwEOPNxbTvqkqO/Y2Pzd8ZLjU9NLzS7XYVTTNrLuUBj29xMzyy0l3k8YmKgwtsr1kei52g+0Vo20cF5hZ7CjDBx/PYmqu0j2yrsSjT04y3oGPgNv9qYGZdZYr3WaWZ066m5RML1lkpTubi91ge0U7xwVmFrtoz/6LEZv7Q2D/6y1vtXuyUuPhPfsWvYqmmXWPUrGf8mSVWm3mSvRmZp3npLsJEZFOL1lcpTubi93oyLzsAsd2tkRkE0Uaj2lpY/eyZH25+7q3PzHe1CqaZtY9snPy+JRbTMwsf5x0N2HfVI0IFl3p
zuZiL6bSvWKgj8MOKTYTZkOyiSKLqXQXC30cvmqwqdcbaXLp+aWabtVxT7dZz8rOyWNuMTGzHHLS3YTshL7YSjckldZG+6e37Rpn43Cp4aXmmzUyXGo8pt1ljhpufm74+lWDDPb3LXulO/v5XOk2613ZObnsiynNLIecdDchO6EvttINSV93w1Xlx5fnwr/FjPEb3T2+pJgkcdQi/vBoldHdza2iaWbdw5VuM8szJ91NmK50L3J6CSxuLva2Xc0tQrNYyUSRiYYmiiQTQJYWUzI2cJkr3U2uomlm3SNbO6HsBXLMLIecgTQhG0m12Dnd0Phc7D37pnhifGrZKt0ADz4+f/V5/9zwpcW0cXhlR3q63c9t1ttKaXvJmJeCN7McctLdhGzxhWYr3bDwXOzRXe2f0b0/pmxW9/zV51bNDR9ZV+Lx8hR7900t6XkWYymraJpZd3Cl28zyzEl3E7ITerM93bDwyLzlmNE9HVODs7NbFdNiZ4Mv1VJX0TSz7lAqutJtZvnlpLsJ5SVML2l0Lva2ZRxxl00U2bZAEtyqueEjyzyrO1tF05Vus942NOhKt5nll5PuJowtodLd6Fzs0d1lhooF1pYGmopxMbKJIo1UulsxN7zRFptWaVVbjJnl23Sl29NLzCyHnHQ3oTzRfKUbGpvVvW3XOCPr2j+jO5PM6l640t2KueHrhoqUioVlq3R7RrfZwWGwv49Cnzyn28xyyUl3E8Ymq0iwor+5pDsZmbdwVXk5k8RGJopsa1FMkpZ1gsno7nEGCmLDKs/oNutlkigVC650m1kutTXplnS6pHslbZV08Sz7hyV9SdLtkm6WdGLdvndIulPSXZLeWbf9vZIelLQl/TqznT/DbMoTFUoDhaZXZdw4vJJHn5yccy52RKSL0CxfO0QjE0VaOXZvOWd1b9td5qi1za+iaWbdY6jY70q3meVS25JuSQXgE8AZwAnAuZJOmHHYHwFbIuK5wBuBj6WPPRF4C3Aq8DzgVZI21T3uoxFxUvp1bbt+hrmMTVabmtGd2T/BZPZK7+PlKZ6cqCxrD/JCFzdmc8NH1rWm+j6yrsTorjIR0ZLnm8/orrL7uc0OEqVigfKUk24zy592VrpPBbZGxH0RMQlcDZw145gTgBsAIuIe4BhJG4BnA9+KiHJEVIBvAq9tY6yLUp6sNDWjO7NxgQQ3277c7SXzxrQri6k1yevG4ZXsnaiwZ7z9HwMv96cGZtY5pcHC9HU3ZmZ50s6k+yhgW9390XRbvduAswEknQo8HdgI3AmcJulQSSXgTGCk7nEXpS0pV0ganu3FJV0gabOkzTt37mzNT5Qam6g2Nbkkk83FnqunOdu+nCsoZpXguS7wbHVM0wvytLmve2yiNatomll3KBX73dNtZrnUzqR7tgbamb0ElwDDkrYAbwe+C1Qi4m7gA8D1wNdIkvPsLPpJ4JnAScB24MOzvXhEXBYRp0TEKevXr1/aTzJDebLS9OQS2D8Xe+5Kdzpto0WtHI0YLg3MO1Gk1dX35RobmM3odnuJ2cFhqFjwnG4zy6V2Jt2jHFid3gg8VH9AROyJiPMi4iSSnu71wP3pvk9HxMkRcRqwC/hBuv2RiKhGRA24nKSNZVmNTS6t0p3NxZ6zqrxrnDUrB1i9ov0zuutjGhkuzV1931XmkMH+ls0Nb3RlzqXyuECzg0tpsN8rUppZLrUz6b4F2CTpWElF4BzgmvoDJK1N9wG8GbgpIvak+w5Pvx9N0oJyVXr/iLqneC1JK8qyKk8srdIN80/vWO5xgZmNw3Mv2pP0Ra9s2dzwNSsHWLWiv+3tJaPLuLKnmXWeK91mllfNl2sXEBEVSRcB1wEF4IqIuEvShen+S0kumLxSUhX4HnB+3VN8QdKhwBTwtojYnW7/oKSTSFpVHgDe2q6fYS7lJVa6IUlwbxt9fNZ923aP8xPrD1nS8zdjZF2Jm+/fRUQ8JblO/hBobeK6HGMDt+1qzSqaZtYdSkVXus0s
nxbMHCX1kYztOxIYB+6KiEcaefJ0nN+1M7ZdWnf7P4FNMx+X7nvRHNvf0Mhrt9PYEqeXwIFzsVfVtZEkM7rLvPi41vahN6J+osia0syYxnnBMw5t+evd/+hYS59zpmxyyXKt7GlmnTU0mFS6ZysemJl10pztJZKeKekyYCvJBY/nAr8NXC/pW5LOSxPyg055YmlzumHuEX2PPjnJvqlaRy78m2uiSLvmho+sSyrd7ZzV3apVNM2sO5SK/VRqwWS11ulQzMwOMF/S/D+A/wU8MyJeGRGvj4hfSReyeQ2wBuh41Xm5TVZqTFZrS690zzGre3pySYd6uutj2B9Te+aGbxxeyfhUlcfGJlv6vPVauYqmmeVfdm72qpRmljdzlmsj4tx59u0A/rIdAeVdtnR7K3q64akj87bt7tyIu/2zug/8Q6Bdc8Pr//A47JDBlj437F9F05Vus4NH9ink2GSF4SFfy2Fm+dFIT/fZs2x+ArgjTb4PKtmiC0udXrJuqEipWHhKK0eWhB+1dvkTxbkmikyP3Wvx3PDs+bbtKnPSyNqWPjfsX0XTM7rNDh5DaUHEE0zMLG8aKdeeD/wscGN6/8XAt4DjJP15RHy2TbHlUjlNupda6ZY064i+0d3jHDpUZGiJPePNmm2iyOju9swN3zhHi02rbOtgq45Zt5C0AngV8CL2XzB/J/DPEXFXJ2NrRiktiHiCiZnlTSOZXQ14djaxRNIGklUhfwa4CTioku6xtE9wqZVuSBLcme0lnZrRnZltoki7LkY8ZLCf4dJA22Z1e0a32fwkvRd4NfAN4NvADmAFcBxwSZqQ/15E3N6pGBfLlW4zy6tGku5jZowI3AEcFxG7JE21Ka7cGmtRpRuSBHfmXOzR3eOccOTqJT93s0bWlfi3Hzz6lJjaNTc8m2DSDtt2lRkqFlq2iqZZD7olIt47x76PpIuUHb2M8SxZqehKt5nlUyMj//5N0lclvUnSm4AvAzdJGgIeb2t0OZRdET/UgqR7ZF2JvRMVnhhP/nap1YIHOzxtY2TGRJFsbvhIi/u5979eidFd7ap0lxlZ5xndZnOJiH9eYP+OiNi8XPG0Qtaa50q3meVNI5nj24BfBl4ICLgS+EIkw5Vf0sbYcmm60t2C9pL6Wd1rS0V27J1gslrrcHvJgRNFsrnhrV6Ncv/rreT6ux+hVgv6+lqbHGdL15vZ/CR9hWSV33pPAJuBv42IfcsfVXOykYHZudrMLC8WrHRH4vMR8a6IeGd6u32rmeRcVj1pRaV7ejGatNI7PZqvg9M29o8NnBlTe5LXjetKTFZq7HxyoqXPGxFs29X6pevNetR9wJPA5enXHuARkt7uyzsY16JlIwM9p9vM8mbBpFvSCyTdIulJSZOSqpL2LEdweZT1Cbai0j1zgZxOLoyTmblS5v6FcdpX6U5ep7UtJo+XpxibrLrSbdaYn4qIX4uIr6RfrwdOjYi3ASfP90BJp0u6V9JWSRfPccyLJW2RdJekb7bjB8isHHCl28zyqZGe7r8mWQL+B8BK4M3Ax9sZVJ5lle7SwNKT7jWlA+diZ4vSdGJGd2ZosJ91Q8W6mNr7h8DI8OwL8ixVHj41MOsi6yVNXzCZ3j4svTvnkrGSCsAngDOAE4BzJZ0w45i1wN8Ar4mI5wCva23oByr0iZUDBfd0m1nuNNQjERFbJRUiogr8naT/aHNcuTU2WWGwv4/+QiN/ryxsY91c7NHdZQ5fNciKFiT0S4tp5QGV7kOHii2Z1jLXayWv09pKd7uWrjfrUb8H/D9JPyS5dudY4LfTC+Y/M8/jTgW2RsR9AJKuBs4Cvld3zK8BX4yIH8P0isZtNTRY8PQSM8udRjKpsqQisEXSB4HtwFB7w8qv8kS1pQvXjNTNxd62azwXldmR4RJ3b086iEZ3l9nYxphWDBRYv2qw9ZXu6Qp9599Ps7yLiGslbQKOJ0m676m7ePIv53noUcC2uvujJGs41DsOGJD0DWAV8LGI
uHLmE0m6ALgA4OijlzalsFTsd6XbzHKnkXLtG9LjLgLGgBGSaSYHpbHJyvQc2FbIKt0RwejjnV0YZ39MKxl9fJxaLZZlAkjyeq2vdK9e0c+alZ7RbdagTcCzgOcC/1XSGxt4zGwjh2ZeaN8P/DTwi8ArgT+RdNxTHhRxWUScEhGnrF+/fnGRz1AqutJtZvnTyPSSH5GsSnkM8EXg4ojY2ua4cqs8UW1p0j2yLpmLvXPvBA89vi8XqydmE0Ue2btvWeaGJytztr6nOw+fGph1A0nvIblW5+Mko2A/CLymgYeOkhRiMhuBh2Y55msRMRYRj5KsZPy8JQc9j6FBV7rNLH8amV7yi8APgb8iuahyq6Qz2h1YXiWV7ta1l2TtD7c8sJtqLXJT6Qb4zo8eX5a54RuHV/LQ4+NUa62bROkZ3WaL8ivAy4CHI+I8kqR4sIHH3QJsknRs2oZ4DnDNjGO+DLxIUr+kEkn7yd2tC/2pSsWCp5eYWe40kj1+GHhJVt2W9Ezgn4F/aWdgeVWerDLUgnGBmWz+9X/e92h6v/PV2ayyvVwxjawrUakFD+/Z15LJLdkqmi8+bmkfUZsdRMYjoiapImk1sAN4xkIPioiKpIuA64ACcEVE3CXpwnT/pRFxt6SvAbeTfGr6qYi4s30/SrKOwo49rZ39b2a2VI0k3TtmtJPcR3JCPiiNTVRYN9S6JDSrdP/HDx9L73e+OpvFsFwxTU8w2VVuSdKdraKZhz9gzLrE5nS03+XArSQL5dzcyAMj4lrg2hnbLp1x/0PAh1oSaQNKg650m1n+zJl0Szo7vXmXpGuBz5FcIPM6ko8UD0rlyer0MsOtcMhgP8OlAe7bOUaf4Ig1nU+6s4ki9+1Mpqq0e2749Kzu3eNPGXvQjG05WGTIrJtExG+nNy9Nq9KrI+L2Tsa0FEOeXmJmOTRfpfvVdbcfAX4+vb0TGG5bRDlXnqxMLzPcKiPrSuwuP8HTVq+g2N+a+d9LtXF4JTv3TizL3PAj1q5A2j/mb6myGd2udJs1Li20/BeS4sr/I2kH6Uolz+k2sxyaM3tML6axGcYmWlvphiTBvX30ibbOw16skeES3/3x48uSuA72F3ja6hXTyfJSZcl7J1f2NOsmkv4G+AngqnTTWyW9PF0GvusMFfuZqNSoVGstW8jMzGypFlWylfSdiDi5XcHkXbUWjE9VW746Y9ZekYdxgZnsAs+RZWrRGBkuTbeFLNXo7jKHDhVbuoiRWY/7eeDEiAgASZ8B7uhsSM3LxrqWp6qsdtJtZjmx2LPRbAshHDTGp5IewVZOL4H9vcd56kHOLvBcrhUdNw6v5MEWVbo9LtBs0e4F6peBHKGL20uyP7jLE+7rNrP8mO9CyndExMckvTAi/j3d/M/LFFculdMewVZXurO2kjz1IE9X39ctT/K6cV2Jf9ryIG/97OYlP9eWHz/Oac/yuECzhUj6CkkP9xrgbknZxJJTgf/oWGBLlFW6PcHEzPJkvuzxPOBjJCuUnQwQEf99OYLKq7HJ9lS6Tx4Z5mXHH84Lf+LQlj7vUjxvZA0vf/bhvPAnDluW13vJs9bzf+95hB89tvQWk6OGV3LmiUe0ICqznvcXnQ6gHYaKrnSbWf7Ml3TfLekBYL2k+o8ZBUREPLetkeXQWJsq3WtKA3z6N57f0udcqlUrBvjUm5Yvpp86epivvv1Fy/Z6ZgYR8c2Z2yS9KiK+2ol4WqU06Eq3meXPfNNLzpX0NJKVxl6zfCHlVzb3dajFSbeZWY78OdDVSfd0pdtJt5nlyLzZY0Q8DDxPUhE4Lt18b0RMtT2yHMqqJqUWt5eYmeVI118wn7UAjrm9xMxyZMHpJZJ+HvgB8Angb4DvSzqt3YHlUdYf6Eq3mfUCSTek3z9Qt/mtHQqnZUqudJtZDjWSPX4EeEVE3Asg6TiSBRR+up2B5dF0pbvFi+OYmXXIEWlh5TWSriapclckZRfPf6ej0TUpK4y40m1medJI0j2Q
JdwAEfF9SQNtjCm3spGBXnTFzHrEnwIXAxtJCiz1AnjpskfUAiuzxXFc6TazHGlkcZzNkj4t6cXp1+XArY08uaTTJd0raauki2fZPyzpS5Jul3SzpBPr9r1D0p2S7pL0zrrt6yRdL+kH6ffhRmJphWxkoCvdZtYLIuLzEXEG8MGIeMmMr65MuAGK/X0UC33T52wzszxoJOn+LeAu4HeAdwDfAy5c6EGSCiR94GcAJwDnSjphxmF/BGxJxw++kWQuOGny/RaSBRqeB7xK0qb0MRcDN0TEJuCG9P6yKE9WKPSJwX4vK2xmvSMi3ifpNZL+Iv16VadjWqrSYGH600kzszxYMHuMiImI+EhEnA38dkR8NCImGnjuU4GtEXFfREwCVwNnzTjmBJLEmYi4BzhG0gbg2cC3IqIcERXgm8Br08ecBXwmvf0Z4JcaiKUlxiaqlIoFpK6/uN/MbJqk97O/qPI94B3ptq41VOx3pdvMcmWxJdvFLAN/FLCt7v5ouq3ebcDZAJJOBZ5O0lt4J3CapEMllYAzgZH0MRsiYjtA+v3w2V5c0gWSNkvavHPnzkWEPbfyZMWTS8ysF/0i8AsRcUVEXAGcnm7rWqViwT3dZpYri026F1Pine3YmHH/EmBY0hbg7cB3gUpE3A18ALge+BpJcr6os2dEXBYRp0TEKevXr1/MQ+c0Nln1jG4z61Vr626v6VQQrVIa7Pf0EjPLlcWWbS9fxLGj7K9OQ1LBfqj+gIjYA5wHoKRn4/70i4j4NPDpdN//TJ8P4BFJR0TEdklHADsW+TM0rTzhSreZ9aT3A9+VdCNJweQ04N2dDWlphlzpNrOcaWRxnM9mtyPib2Zum8ctwCZJx6YrWp4DXDPjudem+wDeDNyUJuJIOjz9fjRJC8pV6XHXAG9Kb78J+HIDsbTE2GTVk0vMrOdExFXAC4Avpl8/GxFXdzaqpSkVXek2s3xppGz7nPo76VSSBRfGiYiKpIuA64ACcEVE3CXpwnT/pSQXTF4pqUpy8c75dU/xBUmHAlPA2yJid7r9EuBzks4Hfgy8roGfoSXKkxUOX7ViuV7OzGzZpNfIXCPpgoh4uNPxLNXQoCvdZpYvcybdkt5NMtJvpaQ92WZgEriskSePiGuBa2dsu7Tu9n8Cm2Y+Lt33ojm2Pwa8rJHXb7XyRJXSoa50m1lPu5AGz/F5VvL0EjPLmTnbSyLi/RGxCvhQRKxOv1ZFxKER0dW9fs0a8/QSM+t9PTETdajoOd1mli+NZJD/Ium0mRsj4qY2xJNr5QlPLzGz3iPp2Ii4P7376lm2dZ3SYD/lqSq1WtDX1xN/R5hZl2sk6f6DutsrSBa9uRXo2iWCmxERrnSbWa/6AnAyQERkk6I+TwPX7+TVULFABOyrVCn5vG1mObDgmSgiXl1/X9II8MG2RZRTE5UatcCVbjPrGZKOJ7lYfo2ks+t2rSYpsnSt0mDyv7dkJWEn3WbWec2ciUaBE1sdSN6Npb2BrnSbWQ95FvAqkoVx6gsse4G3dCKgVhlKx7smE0wGOxuMmRkNJN2SPs7+lST7gJNIVog8qJTTq+A9p9vMekVEfBn4sqSfTadJ9Yysuu1Z3WaWF42UbTfX3a4AV0XEv7cpntwaS+e9Dg260m1mPecCSU+pbEfEb3YimFYYGqyvdJuZdV4jPd2fSVeNPC7ddG97Q8qnrFriSreZ9aCv1t1eAbwWeKhDsbTEdKXbs7rNLCcaaS95MfAZ4AGS+a0jkt50sI0MLLvSbWY9KiK+UH9f0lXAv3YonJaYrnR7VreZ5UQjGeSHgVdExL0Ako4DrqKLR0k1w5VuMzuIbAKO7nQQSzHkSreZ5UwjSfdAlnADRMT3JQ20MaZcGp/y9BIz602S9pJcMK/0+8PAH3Y0qCUqFd3TbWb50tCFlJI+DXw2vf/rJIvjHFSmK92e021mPSYiVnU6hlYbGvT0EjPLl0aS7t8C3gb8DkkV5Cbgb9oZVB5N93S70m1m
PUjSa4DT0rvfiIivznd83g3299EnV7rNLD8amV4yAXwk/TpoZdWSlQOudJtZb5F0CfB84B/STe+Q9MKIeHcHw1oSSQwV+13pNrPcmDPplvQV4DLgaxExNWPfM4DfAB6IiCvaGmFOlCcrlIoF+vrU6VDMzFrtTOCkiKgBSPoM8F2ga5NuSNoBXek2s7yYr9L9FuB3gb+UtAvYSTK/9VhgK/DX6WpmB4Wxyer03Fczsx60FtiV3l7TwThaplTs9/QSM8uNObPIiHgY+G/Af5N0DHAEMA58PyLKyxNefpQnKtNzX83Mesz7ge9KupHk2p3TaLDKLel04GNAAfhURFwyx3HPB74F/GpEfL4lUS+gVCx4TreZ5UZDpduIeIBkcZyDlivdZtarIuIqSd8g6esW8Idp4WVekgrAJ4BfAEaBWyRdExHfm+W4DwDXtTr2+QwV+xlze4mZ5URfpwPoFuXJCkNeGMfMekj6KSYAEbE9Iq6JiC9nCbcSG+d5ilOBrRFxX0RMAlcDZ81y3NuBLwA7Whf9wpKebreXmFk+OOlu0NhElZKXgDez3vIhSV+Q9EZJz5F0uKSjJb1U0vuAfweePc/jjwK21d0fTbdNk3QU8Frg0vkCkXSBpM2SNu/cubO5n2aGZHqJK91mlg8LJt2SXiXpoE/OXek2s14TEa8D/gR4FkmbyL8BXwbeDNwLvDQirp/nKWYb5xQz7v8lSbvKvCXniLgsIk6JiFPWr1/f4E8wv1LRlW4zy49GSrfnAB+T9AXg7yLi7jbHlEtjE+7pNrPek/Zf/3GTDx8FRurubwQemnHMKcDVkgAOA86UVImIf2ryNRs2NOhKt5nlRyOL47xe0mrgXODvJAXwd8BVEbG33QHmRXnS00vMzGa4Bdgk6VjgQZIiza/VHxARx2a3Jf098NXlSLhhf6U7IkiTfjOzjmmobSQi9pBcBHM1yejA1wLfkfT2NsaWK55eYmZ2oIioABeRTCW5G/hcRNwl6UJJF3Y2uqTSXakFk9Vap0MxM1u40i3p1cBvAs8EPgucGhE7JJVITrIfb2+InTdVrTFZqbmn28xshoi4Frh2xrZZL5qMiN9YjpgypfScXZ6oMtjv87eZdVYjle7XAR+NiOdGxIciYgdAukDOb7Y1upzILsTx9BIz60WSXihpKL39ekkfkfT0Tse1VEPpp5Oe1W1medBI0v0e4ObsjqSV2WzXiLihTXHlSjk9YbvSbWY96pNAWdLzSFYi/hFwZWdDWrpSeh2OJ5iYWR40knT/H6C+Ia6abjtojE240m1mPa0SEUGysM3HIuJjwKoOx7Rk05VuTzAxsxxoJIvsT1caAyAiJiUV2xhT7rjSbWY9bq+kdwOvB05Ll20f6HBMSzbd0+1Kt5nlQCOV7p2SXpPdkXQW8Gj7Qsqf6Uq3p5eYWW/6VWACOD9dAv4o4EOdDWnphgZd6Taz/Ggki7wQ+AdJf02y+tg24I1tjSpnskp3yZVuM+tNe0naSqqSjgOOB67qcExL5kq3meVJI4vj/BB4gaRDAB1MC+JkxtITthfHMbMedRPwIknDwA3AZpLq9693NKolmq50e3qJmeVAQ/0Skn4ReA6wIlvVKyL+vI1x5Up5Iqt0u73EzHqSIqIs6Xzg4xHxQUlbOh3UUtXP6TYz67QFe7olXUpS8Xg7SXvJ64CG5rdKOl3SvZK2Srp4lv3Dkr4k6XZJN0s6sW7fuyTdJelOSVdJWpFuf6+kByVtSb/ObPBnbdp0pdtJt5n1Jkn6WZLK9j+n27r+o72S53SbWY40ciHlz0XEG4HdEfFnwM8CIws9KL36/RPAGcAJwLmSTphx2B8BWyLiuSR94h9LH3sU8DvAKRFxIsnJ/5y6x300Ik5Kv66lzbJK90r3dJtZb3on8G7gS+ky7s8AbuxsSEtX6BMrBvrc021mudBI6XZf+r0s6UjgMeDYBh53KrA1Iu4DkHQ1yQzY79UdcwLwfoCIuEfSMZI21MW2UtIU
UAIeauA122Jsskqx0Eexv5G/UczMuktEfBP4pqRVkg5Jz9u/0+m4WmGo2D99MbyZWSc1kkV+RdJakvFR3wEeoLGr2o8imXSSGU231bsNOBtA0qkkbSsbI+JB4C+AHwPbgSci4ut1j7sobUm5Ir3w5ykkXSBps6TNO3fubCDcuZUnK9Mrm5mZ9RpJPynpu8CdwPck3SrpOZ2OqxVKgwX3dJtZLsybdEvqA26IiMcj4gskSfHxEfGnDTy3ZtkWM+5fAgynF+y8HfguUEkT6bNIKupHAkOSXp8+5pPAM4GTSBLyD8/24hFxWUScEhGnrF+/voFw5zY2UXU/t5n1sr8Ffjcinh4RRwO/B1ze4ZhaYqjY755uM8uFeZPuiKhRl9RGxEREPNHgc49yYO/3Rma0iETEnog4LyJOIunpXg/cD7wcuD8idkbEFPBF4OfSxzwSEdU0tstJ2ljaqjxZ8YxuM+tlQxEx3cMdEd8AhjoXTuuUigX3dJtZLjTSXvJ1Sb+sbFZg424BNkk6Nl02/hzgmvoDJK2tW1L+zcBNEbGHpK3kBZJK6eu+DLg7fcwRdU/xWpKPQ9tqbLJKadCVbjPrWfdJ+pP0uppjJP13kgJI1xsa7PeKlGaWC41kkr9LUvGoSNpH0jYSEbF6vgdFREXSRcB1JNNHrkivir8w3X8p8GzgSklVkgssz0/3fVvS50l6yCskbSeXpU/9QUknkbSqPAC8tfEftznliQpDrnSbWe/6TeDPSD5VFMliOed1NKIWKRUL7Nw70ekwzMwaWpFyVbNPno7zu3bGtkvrbv8nsGmOx74HeM8s29/QbDzNGpussrZUXPhAM7MuFBG76ZFpJTO5p9vM8mLBpFvSabNtj4ibWh9OPpUnK14C3sx6jqSv8NQL3KdFxGuWMZy28PQSM8uLRtpL/qDu9gqSCxdvBV7alohyaGyi6iXgzawX/UWnA2g3V7rNLC8aaS95df19SSPAB9sWUQ6VJ93TbWa9J10Up6eViv3sm6pRrQWFvsXOAzAza51myrejwImtDiSvarWg7OklZtbDJN3BU9tMngA2A/8jIh5b/qhaI2sNLE9WWLVioMPRmNnBrJGe7o+z/2TcR7IozW1tjClXxqeSXkBXus2sh/0LUAX+d3r/HJIpJk8Afw+8evaH5V/WGlierDrpNrOOaqR8u7nudgW4KiL+vU3x5E7WC+hKt5n1sBdGxAvr7t8h6d8j4oV1qwF3pazS7VndZtZpjWSSnwf2RUQVQFJBUikiyu0NLR+yq95d6TazHnaIpJ+JiG8DSDoVOCTd19XZan2l28yskxpJum8gWZb9yfT+SuDrpMuy97rpSrenl5hZ73ozcIWkQ0jaSvYA50saAt7f0ciWKCuYuNJtZp3WSCa5IiKyhJuIeFJSqY0x5UpWHfGcbjPrVRFxC/CTktYAiojH63Z/rjNRtUbWGuhKt5l1Wl8Dx4xJOjm7I+mngfH2hZQvWXXElW4z61WS1kj6CMknm/8q6cNpAt71pivdntVtZh3WSCb5TuD/SHoovX8E8KttiyhnXOk2s4PAFcCdwH9N778B+Dvg7I5F1CLTlW6vSmlmHdbI4ji3SDoeeBZJr989ETHV9shyIqt0D7nSbWa965kR8ct19/9M0pZOBdNKrnSbWV4s2F4i6W3AUETcGRF3kFzl/tvtDy0fskp3ydNLzKx3jUv6L9kdSS+kR9oIPb3EzPKikZ7ut9RfVBMRu4G3tC2inMmqI0Oe021mvetC4BOSHpD0APDXwFs7G1JrFPv7GCjI00vMrOMayST7JCkiApI53UCxvWHlR3miSp9gsL+Rv0/MzLpPRNwGPE/S6vT+HknvBG7vaGAtUir2u9JtZh3XSCZ5HfA5SS+T9FLgKuBr7Q0rP8YmKwwV+5HU6VDMzNoqIvZExJ707u92NJgWGioWXOk2s45rpNL9h8AFwG+RXEj5deDydgaVJ+WJKiVPLjGzg0/PVBpKg650m1nnLVjp
johaRFwaEb+SXt1+F/Dx9oeWD1ml28zsIBOdDqBVhooFTy8xs45rKJuUdBJwLsl87vuBL7YxplwpT7rSbWa9SdJeZk+uBaxc5nDaplTs95xuM+u4OZNuSccB55Ak248B/0iyPPBLlim2XBibqHg1SjPrSRGxqtMxLIehwQLbn9jX6TDM7CA3XzZ5D/BvwKsjYiuApHctS1Q5Up6sctghB82wFjOznuPpJWaWB/P1dP8y8DBwo6TLJb2MHrqwplFjk5XpZYTNzOxAkk6XdK+krZIunmX/r0u6Pf36D0nPW+4YhwY9vcTMOm/OpDsivhQRvwocD3wDeBewQdInJb1imeLruPJEdXoZYTMz2y9dt+ETwBnACcC5kk6Ycdj9wM9HxHOB9wGXLW+UrnSbWT40Mr1kLCL+ISJeBWwEtgBPqWb0qrFJ93Sbmc3hVGBrRNwXEZPA1cBZ9QdExH+kKxkDfIvk/yPLKptekq7xZmbWEYtaZjEidkXE30bES9sVUJ5EBOXJKkOeXmJmNpujgG1190fTbXM5H/iX2XZIukDSZkmbd+7c2cIQkzndEbBvqtbS5zUzWwyvbT6PiUqNai1c6TYzm91s1/nMWk6W9BKSpPsPZ9sfEZdFxCkRccr69etbGCLTLYKe1W1mneSkex5ZD6B7us3MZjUKjNTd3wg8NPMgSc8FPgWcFRGPLVNs07LCiWd1m1knOemeR3a1u6eXmJnN6hZgk6RjJRVJ1na4pv4ASUeTLKj2hoj4fgdinG4RdKXbzDrJ2eQ89le6/TaZmc0UERVJFwHXAQXgioi4S9KF6f5LgT8FDgX+RhJAJSJOWc44pyvdTrrNrIOcTc4jq4p4GXgzs9lFxLXAtTO2XVp3+83Am5c7rnrTlW63l5hZB7m9ZB5Z/58r3WZm3cuVbjPLAyfd85iudPtCSjOzrpUVTlzpNrNOamvS3cDywMOSvpQuD3yzpBPr9r1L0l2S7pR0laQV6fZ1kq6X9IP0+3C74s+qIkO+kNLMrGtlLYKudJtZJ7Ut6W5weeA/ArakywO/EfhY+tijgN8BTomIE0ku0DknfczFwA0RsQm4gTaujumRgWZm3W+60u2l4M2sg9pZ6V5weWCSZPwGgIi4BzhG0oZ0Xz+wUlI/UGL/7NezgM+ktz8D/FK7foCsp9sjA83MuteKgT4kKE+40m1mndPOpLuR5YFvA84GkHQq8HRgY0Q8CPwF8GNgO/BERHw9fcyGiNgOkH4/fLYXb8WSwllP98oBV7rNzLqVJIaK/a50m1lHtTPpbmR54EuAYUlbgLcD3wUqaZ/2WcCxwJHAkKTXL+bFW7GkcHmyysqBAoW+2X4UMzPrFqViwT3dZtZR7eybWHB54IjYA5wHoGTVhPvTr1cC90fEznTfF4GfA/4X8IikIyJiu6QjgB3t+gHGJirT813NzKx7DQ32e3qJmXVUOyvdjSwPvDbdB8niCTelifiPgRdIKqXJ+MuAu9PjrgHelN5+E/Dldv0A5cnq9HxXMzPrXisHXOk2s85qW0bZ4PLAzwaulFQFvgecn+77tqTPA98BKiRtJ5elT30J8DlJ55Mk569r188wNlHxjG4zsx4wNFhwpdvMOqqtZdwGlgf+T2DTHI99D/CeWbY/RlL5brvyZNUzus3MekCp2M/j5clOh2FmBzGvSDmPsUlXus3MesHQYMHTS8yso5x0z6M8UZ1eVMHMzLpXqdjvOd1m1lFOuucxNlmZXj7YzMy611DRlW4z6ywn3fMoT7rSbWbWC0qD/Z5eYmYd5aR7Hp5eYmbWG4aKBaaqwWSl1ulQzOwg5aR7DpVqjYlKzXO6zcx6QHYud7XbzDrFSfccylNJ759XpDQz637Zudx93WbWKU6651BOF1FwpdvMrPtNV7o9wcTMOsRJ9xzG0o8gXek2M+t+rnSbWac56Z6DK91mZr3DlW4z6zQn3XOYrnR7eomZWdfLxr+60m1mneKkew7ZFe6lQVe6zcy6XbbQ
maeXmFmnOOmew1jaXuJKt5lZ95uudE+40m1mneGkew6udJuZ9Q5Xus2s05x0z8GVbjOz3lEaSKeXuNJtZh3ipHsO05VuTy8xM+t6/YU+Bvv7XOk2s45x0j2HsckqAwVR7PdbZGbWC4YG+6cnU5mZLTdnlHMoT1Rc5TYz6yGlYmF6DQYzs+XmpHsOY5NV93ObmfWQoaIr3WbWOU6651CerHhyiZlZDykNFih7cRwz6xAn3XMYm3Cl28yslwwV+xnzMvBm1iEu5c6hPOmebjOzXlIqFnj0yYlOh2FtFBEASGrZc9ZqQaUW1CL5Xq0GlVqNagTVWpC+JFH3+tm2TKFPDBT6GCgk3/sLYqCvj76++eOc+doRQaFP9EkU+kRBWvA56kUk8dYi0nghuXVgzDPjB5j5lmb3hegT9ElIjb/3kb5/tTSeWgR9Ev19yc/Wyt9hXjirnMPYRJUj1w50OgwzM2uRdk8v2TdV5dEnJ3j0yUkee3KCx56cBODw1YMcvmoFG1YPMlwqLipJapUswanUJXG1+oSn7na1FkxVa+zZV+GJ8anprz3pV3b/yYkKtdmys6e8dprkBQckfEmilcRWi6BSTV67WguqdfcrtaBaq00nuNnjalF/P9mWyZLALBFMbkNfmrBGGgOxP1mupQlo1G2r1CXV7ZAk40kCDiQ/dy2mk+3FPE+SgENBmvX9yZLsdqt/v5V9R1QjDki0F5L9TIW+NBEvJN+z3+v086c/c/3vupAm7v2FPvrTxw8U+qbf7+Q5+2D6D4cs9v3/NgW85nlH8vITNrTsvXHSPQdXus3MestSp5eMTVS45+G93PPwHu59eC/bn9jHY3VJ9lgD/eL9feLwVYOsX72CDasG2bB6BWtWDjBRqTI+VWXfVI3xqSoTU3X3J6tMVKoHJIPAAUlUlizWasFULahUa1SqwVQt+b6YBG4+g/19rF45wJqVA6xa0U9hgWpkwHTSJSWJVP+MBCn73t8n+guqq3YmCVNf3/7qZ33yfEAyraQyKuqT6AOT82rdHxtZciXtr9RmVVoBpM85cMDr903HUf8lDqz6zkzkIHntqVowValRqdWYqiZ/2Eylv6fJag3gwJ9XyWsW+pj+niWv1TQxr8b+79Xa/j+gZr4/+3/G+vc+jbM+0UyPzW5P/x7r/jurvw/11fPsva6vXu//XfRlfxik1fns91aoi7UW2ScJ+//g2v8HWY2ptNpfqx34x9eBv+v9f7BNpX+wTVWD8alq8u+iFum/idr0f6MH3th/8wXPOHTe/74Xy1nlHMYmqwwNuqfbzKxXNFrprtWCH+8qc8/De7h7e5Jk3/PwXn70WHn6mFWD/Rw1vJJDDynyU0ev5dChQQ49pMj6Q5Lvhx4yyGGHFImAHXsn2LFnHzv2TvDInn08smeCHXv38aPHytzywC6eGJ9isL/AymKBlQMFBgf6WDlQYMVAcn+4NMBgf2G6Qn5gkrc/aRJJMjpQSKp4/VkrQ33Fb0a1sNB3YAW4kCZkA4U+1qwcYPXK/vT7AKtXDLBiwP9fNGuWk+45eE63mVlvKRUL7Juq8eIP3ZhWCJluZZjuna0FE5XadOVRgmMPG+LEI9fwKydv5PgjVvPsI1Zx1NqVDfecjqwrtfPHMrMu4axyDi8/YQM/edSaTodhZmYtcsaJR3DfzjECKIi6j+73tzT09YlioY9nrB/i+Ket5rgNq1jpSVZm1gJOuufwsXN+qtMhmJnlnqTTgY8BBeBTEXHJjP1K958JlIHfiIjvLHugwLOetoq/OtfndjPrDM/pNjOzpkgqAJ8AzgBOAM6VdMKMw84ANqVfFwCfXNYgzcxywkm3mZk161Rga0TcFxGTwNXAWTOOOQu4MhLfAtZKOmK5AzUz6zQn3WZm1qyjgG1190fTbYs9BkkXSNosafPOnTtbHqiZWac56TYzs2bNNr5j5kDoRo4hIi6LiFMi4pT169e3JDgzszxx0m1mZs0aBUbq7m8EHmriGDOz
ntfWpFvS6ZLulbRV0sWz7B+W9CVJt0u6WdKJ6fZnSdpS97VH0jvTfe+V9GDdvjPb+TOYmdmcbgE2STpWUhE4B7hmxjHXAG9U4gXAExGxfbkDNTPrtLaNDKy7qv0XSCodt0i6JiK+V3fYHwFbIuK1ko5Pj39ZRNwLnFT3PA8CX6p73Ecj4i/aFbuZmS0sIiqSLgKuIxkZeEVE3CXpwnT/pcC1JOMCt5KMDDyvU/GamXVSO+d0T1/VDiApu6q9Puk+AXg/QETcI+kYSRsi4pG6Y14G/DAiftTGWM3MrAkRcS1JYl2/7dK62wG8bbnjMjPLm3Ym3bNdsf4zM465DTgb+H+STgWeTtLvV590nwNcNeNxF0l6I7AZ+L2I2D3zxSVdQDITFuBJSfcuMv7DgEcX+Zg8cNzLr1tjd9zLq9m4n97qQPLu1ltvfVTSYgstB9t/F3nQrbE77uXVrXFDi8/bSooQrSfpdcArI+LN6f03AKdGxNvrjllNslLZTwF3AMcDb46I29L9RZILbp6TVb8lbSB5AwJ4H3BERPxmG+LfHBGntPp5281xL79ujd1xL69ujbtbdOv7261xQ/fG7riXV7fGDa2PvZ2V7gWvWI+IPaT9felSwfenX5kzgO/Ut5vU35Z0OfDVlkduZmZmZtZC7ZxesuBV7ZLWpvsA3gzclCbimXOZ0VoyYyWz1wJ3tjxyMzMzM7MWalulu8Gr2p8NXCmpSnKB5fnZ4yWVSCafvHXGU39Q0kkk7SUPzLK/VS5r0/O2m+Neft0au+NeXt0ad7fo1ve3W+OG7o3dcS+vbo0bWhx723q6zczMzMws4RUpzczMzMzazEm3mZmZmVmbOemeYaGl6/NG0gOS7pC0RdLmdNs6SddL+kH6fTgHcV4haYekO+u2zRmnpHenv4N7Jb2yM1HPGfd7JT2YvudbJJ1Zty8vcY9IulHS3ZLukvSOdHuu3/N54u6G93yFpJsl3ZbG/mfp9ly/572gm87bPme3l8/ZuYo91+97R87ZEeGv9Ivkgs8fAs8AiiSL95zQ6bgWiPkB4LAZ2z4IXJzevhj4QA7iPA04GbhzoThJViq9DRgEjk1/J4Ucxf1e4PdnOTZPcR8BnJzeXgV8P40v1+/5PHF3w3su4JD09gDwbeAFeX/Pu/2r287bPmd3JO5uOH905Tl7gdhz/b534pztSveBppeuj4hJIFu6vtucBXwmvf0Z4Jc6F0oiIm4Cds3YPFecZwFXR8RERNwPbCX53Sy7OeKeS57i3h4R30lv7wXuJlklNtfv+TxxzyUXcUOy3HlEPJneHUi/gpy/5z2gF87bPme3iM/Zy69bz9udOGc76T7QbEvXz/cfTh4E8HVJt0rKlr3fEBHbIfnHABzesejmN1ec3fB7uEjS7elHmdlHT7mMW9IxJKu+fpsues9nxA1d8J5LKkjaAuwAro+IrnrPu1S3vY8+Z3dG7s8fmW49Z0P3nbeX+5ztpPtAmmVb3mcqvjAiTiZZvfNtkk7rdEAtkPffwyeBZwInAduBD6fbcxe3pEOALwDvjAMXnnrKobNs61jss8TdFe95RFQj4iSSFXhPlXTiPIfnKvYu1m3vo8/Zy68rzh/Qveds6M7z9nKfs510H2jBpevzJiIeSr/vAL5E8lHHI0pX7ky/7+hchPOaK85c/x4i4pH0H2oNuJz9Hy/lKm5JAyQnwH+IiC+mm3P/ns8Wd7e855mIeBz4BnA6XfCed7mueh99zl5+3XL+6NZzNnT/eXu5ztlOug+04NL1eSJpSNKq7DbwCuBOkpjflB72JuDLnYlwQXPFeQ1wjqRBSccCm4CbOxDfrLJ/jKnXkrznkKO4JQn4NHB3RHykbleu3/O54u6S93y9pLXp7ZXAy4F7yPl73gO65rztc3ZndMn5oyvP2dC95+2OnLMXc9XlwfAFnEly5e0PgT/udDwLxPoMkitpbwPuyuIFDgVuAH6Qfl+Xg1ivIvl4aYrkr8Xz54sT+OP0
d3AvcEbO4v4scAdwe/qP8Igcxv1fSD72uh3Ykn6dmff3fJ64u+E9fy7w3TTGO4E/Tbfn+j3vha9uOW/7nN2xuLvh/NGV5+wFYs/1+96Jc7aXgTczMzMzazO3l5iZmZmZtZmTbjMzMzOzNnPSbWZmZmbWZk66zczMzMzazEm3mZmZmVmbOem2riApJH247v7vS3pvi5777yX9Siuea4HXeZ2kuyXdOGP7MZJ+rd2vb2a2XHzONnsqJ93WLSaAsyUd1ulA6kkqLOLw84HfjoiXzNh+DDDrCVxSf5OhmZl1ks/ZZjM46bZuUQEuA941c8fMqoekJ9PvL5b0TUmfk/R9SZdI+nVJN0u6Q9Iz657m5ZL+LT3uVenjC5I+JOkWSbdLemvd894o6X+TDP6fGc+56fPfKekD6bY/JVlA4FJJH5rxkEuAF0naIuldkn5D0v+R9BXg6+kqdlekcXxX0lkLxHeEpJvS57tT0ouafM/NzJrlc7bP2TaD/yKzbvIJ4HZJH1zEY54HPBvYBdwHfCoiTpX0DuDtwDvT444Bfh54JnCjpJ8A3gg8ERHPlzQI/Lukr6fHnwqcGBH317+YpCOBDwA/DewmOQH/UkT8uaSXAr8fEZtnxHhxuj37H8dvAD8LPDcidkn6n8D/jYjfVLJk7c2S/hX49TniOxu4LiL+v7SqU1rE+2Vm1io+Z/ucbXWcdFvXiIg9kq4EfgcYb/Bht0TEdgBJPwSyE/AdQP1Hhp+LiBrwA0n3AccDrwCeW1eRWQNsAiaBm2eevFPPB74RETvT1/wH4DTgnxqMN3N9ROxKb78CeI2k30/vrwCOnie+W4ArJA0A/xQRWxb52mZmS+Zzts/ZdiAn3dZt/hL4DvB3ddsqpK1SkgQU6/ZN1N2u1d2vceB//zHjdQIQ8PaIuK5+h6QXA2NzxKcF4m9U/fML+OWIuHdGHLPGl+47DfhF4LOSPhQRV7YoLjOzxfhLfM7O4vA5+yDnnm7rKmkl4XMkF7hkHiD5aBDgLGCgiad+naS+tGfwGcC9wHXAb6XVByQdJ2logef5NvDzkg5LPyY8F/jmAo/ZC6yaZ/91wNvTEzaSfqpu+1Pik/R0YEdEXA58Gjh5gdc3M2sLn7N9zrb9XOm2bvRh4KK6+5cDX5Z0M3ADc1c05nMvyYl2A3BhROyT9CmSvsHvpCfPncAvzfckEbFd0ruBG0mqHddGxJcXeO3bgYqk24C/J+krrPc+kmrR7WkcDwCvAuaK78XAH0iaAp4k6XM0M+sUn7N9zjZAETM/oTEzMzMzs1Zye4mZmZmZWZs56TYzMzMzazMn3WZmZmZmbeak28zMzMyszZx0m5mZmZm1mZNuMzMzM7M2c9JtZmZmZtZm/z80xLncveiIiQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Import necessary libraries\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Retrieve the training logs: the out-of-bag evaluation recorded as trees are added\n",
    "logs = model_1.make_inspector().training_logs()\n",
    "\n",
    "# Plot the logs\n",
    "plt.figure(figsize=(12, 4))\n",
    "\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot([log.num_trees for log in logs], [log.evaluation.accuracy for log in logs])\n",
    "plt.xlabel(\"Number of trees\")\n",
    "plt.ylabel(\"Accuracy (out-of-bag)\")\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot([log.num_trees for log in logs], [log.evaluation.loss for log in logs])\n",
    "plt.xlabel(\"Number of trees\")\n",
    "plt.ylabel(\"Logloss (out-of-bag)\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w1xzugBRhwuN"
   },
   "source": [
    "Because this dataset is small, the model converges almost immediately.\n",
    "\n",
    "You can also explore the training logs with TensorBoard:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "5R_m-JmvU9tu"
   },
   "outputs": [],
   "source": [
    "# Load the TensorBoard notebook extension. Starting TensorBoard can be slow.\n",
    "%load_ext tensorboard\n",
    "# Google internal version\n",
    "# %load_ext google3.learning.brain.tensorboard.notebook.extension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "j6mp7K6HWwqQ"
   },
   "outputs": [],
   "source": [
    "# Clear existing results (if any)\n",
    "!rm -fr \"/tmp/tensorboard_logs\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "id": "16NbLILYo124"
   },
   "outputs": [],
   "source": [
    "# Export the model's training logs to TensorBoard.\n",
    "model_1.make_inspector().export_to_tensorboard(\"/tmp/tensorboard_logs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "TSsN6aTXW0LJ"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-145bd69f0a4adb94\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-145bd69f0a4adb94\");\n",
       "          const url = new URL(\"/proxy/6006/\", window.location);\n",
       "          const port = 0;\n",
       "          if (port) {\n",
       "            url.port = port;\n",
       "          }\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# docs_infra: no_execute\n",
    "# Start a tensorboard instance.\n",
    "%tensorboard --logdir \"/tmp/tensorboard_logs\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r_tlSccjZ8kE"
   },
   "source": [
    "<!-- <img class=\"tfo-display-only-on-site\" src=\"images/beginner_tensorboard.png\"/> -->\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "phTUr6F1t-_E"
   },
   "source": [
    "## Re-train the model with a different learning algorithm\n",
    "\n",
    "The learning algorithm is determined by the model class. For\n",
    "example, `tfdf.keras.RandomForestModel()` trains a Random Forest, while\n",
    "`tfdf.keras.GradientBoostedTreesModel()` trains a Gradient Boosted Decision\n",
    "Trees model.\n",
    "\n",
    "To list the available learning algorithms, call `tfdf.keras.get_all_models()` or see the\n",
    "[learner list](https://github.com/google/yggdrasil-decision-forests/manual/learners)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "id": "OwEAAzUZq2m8"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
        "[tensorflow_decision_forests.keras.RandomForestModel,\n",
        " tensorflow_decision_forests.keras.GradientBoostedTreesModel,\n",
        " tensorflow_decision_forests.keras.CartModel,\n",
        " tensorflow_decision_forests.keras.DistributedGradientBoostedTreesModel]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List all algorithms\n",
    "tfdf.keras.get_all_models()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xmzvuI78voD4"
   },
   "source": [
    "Descriptions of the learning algorithms and their hyper-parameters are also available in the [API reference](https://www.tensorflow.org/decision_forests/api_docs/python/tfdf) and in the built-in help:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "id": "2hONToBav4DE"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on class RandomForestModel in module tensorflow_decision_forests.keras:\n",
      "\n",
      "class RandomForestModel(tensorflow_decision_forests.keras.wrappers.RandomForestModel)\n",
      " |  RandomForestModel(*args, **kwargs)\n",
      " |  \n",
      " |  Random Forest learning algorithm.\n",
      " |  \n",
      " |  A Random Forest (https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf)\n",
      " |  is a collection of deep CART decision trees trained independently and without\n",
      " |  pruning. Each tree is trained on a random subset of the original training \n",
      " |  dataset (sampled with replacement).\n",
      " |  \n",
      " |  The algorithm is unique in that it is robust to overfitting, even in extreme\n",
      " |  cases e.g. when there is more features than training examples.\n",
      " |  \n",
      " |  It is probably the most well-known of the Decision Forest training\n",
      " |  algorithms.\n",
      " |  \n",
      " |  Usage example:\n",
      " |  \n",
      " |  ```python\n",
      " |  import tensorflow_decision_forests as tfdf\n",
      " |  import pandas as pd\n",
      " |  \n",
      " |  dataset = pd.read_csv(\"project/dataset.csv\")\n",
      " |  tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset, label=\"my_label\")\n",
      " |  \n",
      " |  model = tfdf.keras.RandomForestModel()\n",
      " |  model.fit(tf_dataset)\n",
      " |  \n",
      " |  print(model.summary())\n",
      " |  ```\n",
      " |  \n",
      " |  Attributes:\n",
      " |    task: Task to solve (e.g. Task.CLASSIFICATION, Task.REGRESSION,\n",
      " |      Task.RANKING).\n",
      " |    features: Specify the list and semantic of the input features of the model.\n",
      " |      If not specified, all the available features will be used. If specified\n",
      " |      and if `exclude_non_specified_features=True`, only the features in\n",
      " |      `features` will be used by the model. If \"preprocessing\" is used,\n",
      " |      `features` corresponds to the output of the preprocessing. In this case,\n",
      " |      it is recommended for the preprocessing to return a dictionary of tensors.\n",
      " |    exclude_non_specified_features: If true, only use the features specified in\n",
      " |      `features`.\n",
      " |    preprocessing: Functional keras model or @tf.function to apply on the input\n",
      " |      feature before the model to train. This preprocessing model can consume\n",
      " |      and return tensors, list of tensors or dictionary of tensors. If\n",
      " |      specified, the model only \"sees\" the output of the preprocessing (and not\n",
      " |      the raw input). Can be used to prepare the features or to stack multiple\n",
      " |      models on top of each other. Unlike preprocessing done in the tf.dataset,\n",
      " |      the operation in \"preprocessing\" are serialized with the model.\n",
      " |    postprocessing: Like \"preprocessing\" but applied on the model output.\n",
      " |    ranking_group: Only for `task=Task.RANKING`. Name of a tf.string feature that\n",
      " |      identifies queries in a query/document ranking task. The ranking group\n",
      " |      is not added automatically for the set of features if\n",
      " |      `exclude_non_specified_features=false`.\n",
      " |    temp_directory: Temporary directory used to store the model Assets after the\n",
      " |      training, and possibly as a work directory during the training. This\n",
      " |      temporary directory is necessary for the model to be exported after\n",
      " |      training e.g. `model.save(path)`. If not specified, `temp_directory` is\n",
      " |      set to a temporary directory using `tempfile.TemporaryDirectory`. This\n",
      " |      directory is deleted when the model python object is garbage-collected.\n",
      " |    verbose: If true, displays information about the training.\n",
      " |    hyperparameter_template: Override the default value of the hyper-parameters.\n",
      " |      If None (default) the default parameters of the library are used. If set,\n",
      " |      `default_hyperparameter_template` refers to one of the following\n",
      " |      preconfigured hyper-parameter sets. Those sets outperforms the default\n",
      " |      hyper-parameters (either generally or in specific scenarios).\n",
      " |      You can omit the version (e.g. remove \"@v5\") to use the last version of\n",
      " |      the template. In this case, the hyper-parameter can change in between\n",
      " |      releases (not recommended for training in production).\n",
      " |      - better_default@v1: A configuration that is generally better than the\n",
      " |        default parameters without being more expensive. The parameters are:\n",
      " |        winner_take_all=True.\n",
      " |      - benchmark_rank1@v1: Top ranking hyper-parameters on our benchmark\n",
      " |        slightly modified to run in reasonable time. The parameters are:\n",
      " |        winner_take_all=True, categorical_algorithm=\"RANDOM\",\n",
      " |        split_axis=\"SPARSE_OBLIQUE\", sparse_oblique_normalization=\"MIN_MAX\",\n",
      " |        sparse_oblique_num_projections_exponent=1.0.\n",
      " |  \n",
      " |    advanced_arguments: Advanced control of the model that most users won't need\n",
      " |      to use. See `AdvancedArguments` for details.\n",
      " |    num_threads: Number of threads used to train the model. Different learning\n",
      " |      algorithms use multi-threading differently and with different degree of\n",
      " |      efficiency. If specified, `num_threads` field of the\n",
      " |      `advanced_arguments.yggdrasil_deployment_config` has priority.\n",
      " |    name: The name of the model.\n",
      " |    adapt_bootstrap_size_ratio_for_maximum_training_duration: Control how the\n",
      " |      maximum training duration (if set) is applied. If false, the training\n",
      " |      stop when the time is used. If true, adapts the size of the sampled\n",
      " |      dataset used to train each tree such that `num_trees` will train within\n",
      " |      `maximum_training_duration`. Has no effect if there is no maximum\n",
      " |      training duration specified. Default: False.\n",
      " |    allow_na_conditions: If true, the tree training evaluates conditions of the\n",
      " |      type `X is NA` i.e. `X is missing`. Default: False.\n",
      " |    categorical_algorithm: How to learn splits on categorical attributes.\n",
      " |      - `CART`: CART algorithm. Find categorical splits of the form \"value \\\\in\n",
      " |        mask\". The solution is exact for binary classification, regression and\n",
      " |        ranking. It is approximated for multi-class classification. This is a\n",
      " |        good first algorithm to use. In case of overfitting (very small\n",
      " |        dataset, large dictionary), the \"random\" algorithm is a good\n",
      " |        alternative.\n",
      " |      - `ONE_HOT`: One-hot encoding. Find the optimal categorical split of the\n",
      " |        form \"attribute == param\". This method is similar (but more efficient)\n",
      " |        than converting converting each possible categorical value into a\n",
      " |        boolean feature. This method is available for comparison purpose and\n",
      " |        generally performs worse than other alternatives.\n",
      " |      - `RANDOM`: Best splits among a set of random candidate. Find the a\n",
      " |        categorical split of the form \"value \\\\in mask\" using a random search.\n",
      " |        This solution can be seen as an approximation of the CART algorithm.\n",
      " |        This method is a strong alternative to CART. This algorithm is inspired\n",
      " |        from section \"5.1 Categorical Variables\" of \"Random Forest\", 2001.\n",
      " |        Default: \"CART\".\n",
      " |    categorical_set_split_greedy_sampling: For categorical set splits e.g.\n",
      " |      texts. Probability for a categorical value to be a candidate for the\n",
      " |      positive set. The sampling is applied once per node (i.e. not at every\n",
      " |      step of the greedy optimization). Default: 0.1.\n",
      " |    categorical_set_split_max_num_items: For categorical set splits e.g. texts.\n",
      " |      Maximum number of items (prior to the sampling). If more items are\n",
      " |      available, the least frequent items are ignored. Changing this value is\n",
      " |      similar to change the \"max_vocab_count\" before loading the dataset, with\n",
      " |      the following exception: With `max_vocab_count`, all the remaining items\n",
      " |      are grouped in a special Out-of-vocabulary item. With `max_num_items`,\n",
      " |      this is not the case. Default: -1.\n",
      " |    categorical_set_split_min_item_frequency: For categorical set splits e.g.\n",
      " |      texts. Minimum number of occurrences of an item to be considered.\n",
      " |      Default: 1.\n",
      " |    compute_oob_performances: If true, compute the Out-of-bag evaluation (then\n",
      " |      available in the summary and model inspector). This evaluation is a cheap\n",
      " |      alternative to cross-validation evaluation. Default: True.\n",
      " |    compute_oob_variable_importances: If true, compute the Out-of-bag feature\n",
      " |      importance (then available in the summary and model inspector). Note that\n",
      " |      the OOB feature importance can be expensive to compute. Default: False.\n",
      " |    growing_strategy: How to grow the tree.\n",
      " |      - `LOCAL`: Each node is split independently of the other nodes. In other\n",
      " |        words, as long as a node satisfy the splits \"constraints (e.g. maximum\n",
      " |        depth, minimum number of observations), the node will be split. This is\n",
      " |        the \"classical\" way to grow decision trees.\n",
      " |      - `BEST_FIRST_GLOBAL`: The node with the best loss reduction among all\n",
      " |        the nodes of the tree is selected for splitting. This method is also\n",
      " |        called \"best first\" or \"leaf-wise growth\". See \"Best-first decision\n",
      " |        tree learning\", Shi and \"Additive logistic regression : A statistical\n",
      " |        view of boosting\", Friedman for more details. Default: \"LOCAL\".\n",
      " |    in_split_min_examples_check: Whether to check the `min_examples` constraint\n",
      " |      in the split search (i.e. splits leading to one child having less than\n",
      " |      `min_examples` examples are considered invalid) or before the split\n",
      " |      search (i.e. a node can be derived only if it contains more than\n",
      " |      `min_examples` examples). If false, there can be nodes with less than\n",
      " |      `min_examples` training examples. Default: True.\n",
      " |    max_depth: Maximum depth of the tree. `max_depth=1` means that all trees\n",
      " |      will be roots. Negative values are ignored. Default: 16.\n",
      " |    max_num_nodes: Maximum number of nodes in the tree. Set to -1 to disable\n",
      " |      this limit. Only available for `growing_strategy=BEST_FIRST_GLOBAL`.\n",
      " |      Default: None.\n",
      " |    maximum_training_duration_seconds: Maximum training duration of the model\n",
      " |      expressed in seconds. Each learning algorithm is free to use this\n",
      " |      parameter at it sees fit. Enabling maximum training duration makes the\n",
      " |      model training non-deterministic. Default: -1.0.\n",
      " |    min_examples: Minimum number of examples in a node. Default: 5.\n",
      " |    missing_value_policy: Method used to handle missing attribute values.\n",
      " |      - `GLOBAL_IMPUTATION`: Missing attribute values are imputed, with the\n",
      " |        mean (in case of numerical attribute) or the most-frequent-item (in\n",
      " |        case of categorical attribute) computed on the entire dataset (i.e. the\n",
      " |        information contained in the data spec).\n",
      " |      - `LOCAL_IMPUTATION`: Missing attribute values are imputed with the mean\n",
      " |        (numerical attribute) or most-frequent-item (in the case of categorical\n",
      " |        attribute) evaluated on the training examples in the current node.\n",
      " |      - `RANDOM_LOCAL_IMPUTATION`: Missing attribute values are imputed from\n",
      " |        randomly sampled values from the training examples in the current node.\n",
      " |        This method was proposed by Clinic et al. in \"Random Survival Forests\"\n",
      " |        (https://projecteuclid.org/download/pdfview_1/euclid.aoas/1223908043).\n",
      " |        Default: \"GLOBAL_IMPUTATION\".\n",
      " |    num_candidate_attributes: Number of unique valid attributes tested for each\n",
      " |      node. An attribute is valid if it has at least a valid split. If\n",
      " |      `num_candidate_attributes=0`, the value is set to the classical default\n",
      " |      value for Random Forest: `sqrt(number of input attributes)` in case of\n",
      " |      classification and `number_of_input_attributes / 3` in case of\n",
      " |      regression. If `num_candidate_attributes=-1`, all the attributes are\n",
      " |      tested. Default: 0.\n",
      " |    num_candidate_attributes_ratio: Ratio of attributes tested at each node. If\n",
      " |      set, it is equivalent to `num_candidate_attributes =\n",
      " |      number_of_input_features x num_candidate_attributes_ratio`. The possible\n",
      " |      values are between ]0, and 1] as well as -1. If not set or equal to -1,\n",
      " |      the `num_candidate_attributes` is used. Default: -1.0.\n",
      " |    num_trees: Number of individual decision trees. Increasing the number of\n",
      " |      trees can increase the quality of the model at the expense of size,\n",
      " |      training speed, and inference latency. Default: 300.\n",
      " |    sorting_strategy: How are sorted the numerical features in order to find\n",
      " |      the splits\n",
      " |      - PRESORT: The features are pre-sorted at the start of the training. This\n",
      " |        solution is faster but consumes much more memory than IN_NODE.\n",
      " |      - IN_NODE: The features are sorted just before being used in the node.\n",
      " |        This solution is slow but consumes little amount of memory.\n",
      " |      . Default: \"PRESORT\".\n",
      " |    sparse_oblique_normalization: For sparse oblique splits i.e.\n",
      " |      `split_axis=SPARSE_OBLIQUE`. Normalization applied on the features,\n",
      " |      before applying the sparse oblique projections.\n",
      " |      - `NONE`: No normalization.\n",
      " |      - `STANDARD_DEVIATION`: Normalize the feature by the estimated standard\n",
      " |        deviation on the entire train dataset. Also known as Z-Score\n",
      " |        normalization.\n",
      " |      - `MIN_MAX`: Normalize the feature by the range (i.e. max-min) estimated\n",
      " |        on the entire train dataset. Default: None.\n",
      " |    sparse_oblique_num_projections_exponent: For sparse oblique splits i.e.\n",
      " |      `split_axis=SPARSE_OBLIQUE`. Controls of the number of random projections\n",
      " |      to test at each node as `num_features^num_projections_exponent`. Default:\n",
      " |      None.\n",
      " |    sparse_oblique_projection_density_factor: For sparse oblique splits i.e.\n",
      " |      `split_axis=SPARSE_OBLIQUE`. Controls of the number of random projections\n",
      " |      to test at each node as `num_features^num_projections_exponent`. Default:\n",
      " |      None.\n",
      " |    split_axis: What structure of split to consider for numerical features.\n",
      " |      - `AXIS_ALIGNED`: Axis aligned splits (i.e. one condition at a time).\n",
      " |        This is the \"classical\" way to train a tree. Default value.\n",
      " |      - `SPARSE_OBLIQUE`: Sparse oblique splits (i.e. splits one a small number\n",
      " |        of features) from \"Sparse Projection Oblique Random Forests\", Tomita et\n",
      " |        al., 2020. Default: \"AXIS_ALIGNED\".\n",
      " |    winner_take_all: Control how classification trees vote. If true, each tree\n",
      " |      votes for one class. If false, each tree vote for a distribution of\n",
      " |      classes. winner_take_all_inference=false is often preferable. Default:\n",
      " |      True.\n",
      " |  \n",
      " |  Method resolution order:\n",
      " |      RandomForestModel\n",
      " |      tensorflow_decision_forests.keras.wrappers.RandomForestModel\n",
      " |      tensorflow_decision_forests.keras.core.CoreModel\n",
      " |      keras.engine.training.Model\n",
      " |      keras.engine.base_layer.Layer\n",
      " |      tensorflow.python.module.module.Module\n",
      " |      tensorflow.python.training.tracking.tracking.AutoTrackable\n",
      " |      tensorflow.python.training.tracking.base.Trackable\n",
      " |      keras.utils.version_utils.LayerVersionSelector\n",
      " |      keras.utils.version_utils.ModelVersionSelector\n",
      " |      builtins.object\n",
      " |  \n",
      " |  Methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:\n",
      " |  \n",
      " |  __init__ = wrapper(*args, **kargs)\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Static methods inherited from tensorflow_decision_forests.keras.wrappers.RandomForestModel:\n",
      " |  \n",
      " |  predefined_hyperparameters() -> List[tensorflow_decision_forests.keras.core.HyperParameterTemplate]\n",
      " |      Returns a better than default set of hyper-parameters.\n",
      " |      \n",
      " |      They can be used directly with the `hyperparameter_template` argument of the\n",
      " |      model constructor.\n",
      " |      \n",
      " |      These hyper-parameters outperforms the default hyper-parameters (either\n",
      " |      generally or in specific scenarios). Like default hyper-parameters, existing\n",
      " |      pre-defined hyper-parameters cannot change.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Methods inherited from tensorflow_decision_forests.keras.core.CoreModel:\n",
      " |  \n",
      " |  call(self, inputs, training=False)\n",
      " |      Inference of the model.\n",
      " |      \n",
      " |      This method is used for prediction and evaluation of a trained model.\n",
      " |      \n",
      " |      Args:\n",
      " |        inputs: Input tensors.\n",
      " |        training: Is the model being trained. Always False.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Model predictions.\n",
      " |  \n",
      " |  compile(self, metrics=None)\n",
      " |      Configure the model for training.\n",
      " |      \n",
      " |      Unlike for most Keras model, calling \"compile\" is optional before calling\n",
      " |      \"fit\".\n",
      " |      \n",
      " |      Args:\n",
      " |        metrics: Metrics to report during training.\n",
      " |      \n",
      " |      Raises:\n",
      " |        ValueError: Invalid arguments.\n",
      " |  \n",
      " |  evaluate(self, *args, **kwargs)\n",
      " |      Returns the loss value & metrics values for the model.\n",
      " |      \n",
      " |      See details on `keras.Model.evaluate`.\n",
      " |      \n",
      " |      Args:\n",
      " |        *args: Passed to `keras.Model.evaluate`.\n",
      " |        **kwargs: Passed to `keras.Model.evaluate`.  Scalar test loss (if the\n",
      " |          model has a single output and no metrics) or list of scalars (if the\n",
      " |          model has multiple outputs and/or metrics). See details in\n",
      " |          `keras.Model.evaluate`.\n",
      " |  \n",
      " |  fit(self, x=None, y=None, callbacks=None, **kwargs) -> keras.callbacks.History\n",
      " |      Trains the model.\n",
      " |      \n",
      " |      The following dataset formats are supported:\n",
      " |      \n",
      " |        1. \"x\" is a tf.data.Dataset containing a tuple \"(features, labels)\".\n",
      " |           \"features\" can be a dictionary a tensor, a list of tensors or a\n",
      " |           dictionary of tensors (recommended). \"labels\" is a tensor.\n",
      " |      \n",
      " |        2. \"x\" is a tensor, list of tensors or dictionary of tensors containing\n",
      " |           the input features. \"y\" is a tensor.\n",
      " |      \n",
      " |        3. \"x\" is a numpy-array, list of numpy-arrays or dictionary of\n",
      " |           numpy-arrays containing the input features. \"y\" is a numpy-array.\n",
      " |      \n",
      " |      Pandas Dataframe can be consumed with \"dataframe_to_tf_dataset\":\n",
      " |        dataset = pandas.Dataframe(...)\n",
      " |        model.fit(pd_dataframe_to_tf_dataset(dataset, label=\"my_label\"))\n",
      " |      \n",
      " |      Args:\n",
      " |        x: Training dataset (See details above for the supported formats).\n",
      " |        y: Label of the training dataset. Only used if \"x\" does not contains the\n",
      " |          labels.\n",
      " |        callbacks: Callbacks triggered during the training.\n",
      " |        **kwargs: Arguments passed to the core keras model's fit.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A `History` object. Its `History.history` attribute is not yet\n",
      " |        implemented for decision forests algorithms, and will return empty.\n",
      " |        All other fields are filled as usual for `Keras.Mode.fit()`.\n",
      " |  \n",
      " |  make_inspector(self) -> tensorflow_decision_forests.component.inspector.inspector.AbstractInspector\n",
      " |      Creates an inspector to access the internal model structure.\n",
      " |      \n",
      " |      Usage example:\n",
      " |      \n",
      " |      ```python\n",
      " |      inspector = model.make_inspector()\n",
      " |      print(inspector.num_trees())\n",
      " |      print(inspector.variable_importances())\n",
      " |      ```\n",
      " |      \n",
      " |      Returns:\n",
      " |        A model inspector.\n",
      " |  \n",
      " |  make_predict_function(self)\n",
      " |      Prediction of the model (!= evaluation).\n",
      " |  \n",
      " |  make_test_function(self)\n",
      " |      Predictions for evaluation.\n",
      " |  \n",
      " |  save(self, filepath: str, overwrite: Union[bool, NoneType] = True, **kwargs)\n",
      " |      Saves the model as a TensorFlow SavedModel.\n",
      " |      \n",
      " |      The exported SavedModel contains a standalone Yggdrasil Decision Forests\n",
      " |      model in the \"assets\" sub-directory. The Yggdrasil model can be used\n",
      " |      directly using the Yggdrasil API. However, this model does not contain the\n",
      " |      \"preprocessing\" layer (if any).\n",
      " |      \n",
      " |      Args:\n",
      " |        filepath: Path to the output model.\n",
      " |        overwrite: If true, overwrite an already existing model. If false, raise an\n",
      " |          error if a model already exists.\n",
      " |        **kwargs: Arguments passed to the core keras model's save.\n",
      " |  \n",
      " |  summary(self, line_length=None, positions=None, print_fn=None)\n",
      " |      Shows information about the model.\n",
      " |  \n",
      " |  train_step(self, data)\n",
      " |      Collects training examples.\n",
      " |  \n",
      " |  yggdrasil_model_path_tensor(self) -> Union[tensorflow.python.framework.ops.Tensor, NoneType]\n",
      " |      Gets the path to yggdrasil model, if available.\n",
      " |      \n",
      " |      The effective path can be obtained with:\n",
      " |      \n",
      " |      ```python\n",
      " |      yggdrasil_model_path_tensor().numpy().decode(\"utf-8\")\n",
      " |      ```\n",
      " |      \n",
      " |      Returns:\n",
      " |        Path to the Yggdrasil model.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Methods inherited from keras.engine.training.Model:\n",
      " |  \n",
      " |  __setattr__(self, name, value)\n",
      " |      Support self.foo = trackable syntax.\n",
      " |  \n",
      " |  build(self, input_shape)\n",
      " |      Builds the model based on input shapes received.\n",
      " |      \n",
      " |      This is to be used for subclassed models, which do not know at instantiation\n",
      " |      time what their inputs look like.\n",
      " |      \n",
      " |      This method only exists for users who want to call `model.build()` in a\n",
      " |      standalone way (as a substitute for calling the model on real data to\n",
      " |      build it). It will never be called by the framework (and thus it will\n",
      " |      never throw unexpected errors in an unrelated workflow).\n",
      " |      \n",
      " |      Args:\n",
      " |       input_shape: Single tuple, TensorShape, or list/dict of shapes, where\n",
      " |           shapes are tuples, integers, or TensorShapes.\n",
      " |      \n",
      " |      Raises:\n",
      " |        ValueError:\n",
      " |          1. In case of invalid user-provided data (not of type tuple,\n",
      " |             list, TensorShape, or dict).\n",
      " |          2. If the model requires call arguments that are agnostic\n",
      " |             to the input shapes (positional or kwarg in call signature).\n",
      " |          3. If not all layers were properly built.\n",
      " |          4. If float type inputs are not supported within the layers.\n",
      " |      \n",
      " |        In each of these cases, the user should build their model by calling it\n",
      " |        on real tensor data.\n",
      " |  \n",
      " |  evaluate_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)\n",
      " |      Evaluates the model on a data generator.\n",
      " |      \n",
      " |      DEPRECATED:\n",
      " |        `Model.evaluate` now supports generators, so there is no longer any need\n",
      " |        to use this endpoint.\n",
      " |  \n",
      " |  fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, validation_freq=1, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)\n",
      " |      Fits the model on data yielded batch-by-batch by a Python generator.\n",
      " |      \n",
      " |      DEPRECATED:\n",
      " |        `Model.fit` now supports generators, so there is no longer any need to use\n",
      " |        this endpoint.\n",
      " |  \n",
      " |  get_config(self)\n",
      " |      Returns the config of the layer.\n",
      " |      \n",
      " |      A layer config is a Python dictionary (serializable)\n",
      " |      containing the configuration of a layer.\n",
      " |      The same layer can be reinstantiated later\n",
      " |      (without its trained weights) from this configuration.\n",
      " |      \n",
      " |      The config of a layer does not include connectivity\n",
      " |      information, nor the layer class name. These are handled\n",
      " |      by `Network` (one layer of abstraction above).\n",
      " |      \n",
      " |      Note that `get_config()` does not guarantee to return a fresh copy of dict\n",
      " |      every time it is called. The callers should make a copy of the returned dict\n",
      " |      if they want to modify it.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Python dictionary.\n",
      " |  \n",
      " |  get_layer(self, name=None, index=None)\n",
      " |      Retrieves a layer based on either its name (unique) or index.\n",
      " |      \n",
      " |      If `name` and `index` are both provided, `index` will take precedence.\n",
      " |      Indices are based on order of horizontal graph traversal (bottom-up).\n",
      " |      \n",
      " |      Args:\n",
      " |          name: String, name of layer.\n",
      " |          index: Integer, index of layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A layer instance.\n",
      " |      \n",
      " |      Raises:\n",
      " |          ValueError: In case of invalid layer name or index.\n",
      " |  \n",
      " |  get_weights(self)\n",
      " |      Retrieves the weights of the model.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A flat list of Numpy arrays.\n",
      " |  \n",
      " |  load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None)\n",
      " |      Loads all layer weights, either from a TensorFlow or an HDF5 weight file.\n",
      " |      \n",
      " |      If `by_name` is False weights are loaded based on the network's\n",
      " |      topology. This means the architecture should be the same as when the weights\n",
      " |      were saved.  Note that layers that don't have weights are not taken into\n",
      " |      account in the topological ordering, so adding or removing layers is fine as\n",
      " |      long as they don't have weights.\n",
      " |      \n",
      " |      If `by_name` is True, weights are loaded into layers only if they share the\n",
      " |      same name. This is useful for fine-tuning or transfer-learning models where\n",
      " |      some of the layers have changed.\n",
      " |      \n",
      " |      Only topological loading (`by_name=False`) is supported when loading weights\n",
      " |      from the TensorFlow format. Note that topological loading differs slightly\n",
      " |      between TensorFlow and HDF5 formats for user-defined classes inheriting from\n",
      " |      `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the\n",
      " |      TensorFlow format loads based on the object-local names of attributes to\n",
      " |      which layers are assigned in the `Model`'s constructor.\n",
      " |      \n",
      " |      Args:\n",
      " |          filepath: String, path to the weights file to load. For weight files in\n",
      " |              TensorFlow format, this is the file prefix (the same as was passed\n",
      " |              to `save_weights`). This can also be a path to a SavedModel\n",
      " |              saved from `model.save`.\n",
      " |          by_name: Boolean, whether to load weights by name or by topological\n",
      " |              order. Only topological loading is supported for weight files in\n",
      " |              TensorFlow format.\n",
      " |          skip_mismatch: Boolean, whether to skip loading of layers where there is\n",
      " |              a mismatch in the number of weights, or a mismatch in the shape of\n",
      " |              the weight (only valid when `by_name=True`).\n",
      " |          options: Optional `tf.train.CheckpointOptions` object that specifies\n",
      " |              options for loading weights.\n",
      " |      \n",
      " |      Returns:\n",
      " |          When loading a weight file in TensorFlow format, returns the same status\n",
      " |          object as `tf.train.Checkpoint.restore`. When graph building, restore\n",
      " |          ops are run automatically as soon as the network is built (on first call\n",
      " |          for user-defined classes inheriting from `Model`, immediately if it is\n",
      " |          already built).\n",
      " |      \n",
      " |          When loading weights in HDF5 format, returns `None`.\n",
      " |      \n",
      " |      Raises:\n",
      " |          ImportError: If h5py is not available and the weight file is in HDF5\n",
      " |              format.\n",
      " |          ValueError: If `skip_mismatch` is set to `True` when `by_name` is\n",
      " |            `False`.\n",
      " |  \n",
      " |  make_train_function(self, force=False)\n",
      " |      Creates a function that executes one step of training.\n",
      " |      \n",
      " |      This method can be overridden to support custom training logic.\n",
      " |      This method is called by `Model.fit` and `Model.train_on_batch`.\n",
      " |      \n",
      " |      Typically, this method directly controls `tf.function` and\n",
      " |      `tf.distribute.Strategy` settings, and delegates the actual training\n",
      " |      logic to `Model.train_step`.\n",
      " |      \n",
      " |      This function is cached the first time `Model.fit` or\n",
      " |      `Model.train_on_batch` is called. The cache is cleared whenever\n",
      " |      `Model.compile` is called. You can skip the cache and generate again the\n",
      " |      function with `force=True`.\n",
      " |      \n",
      " |      Args:\n",
      " |        force: Whether to regenerate the train function and skip the cached\n",
      " |          function if available.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Function. The function created by this method should accept a\n",
      " |        `tf.data.Iterator`, and return a `dict` containing values that will\n",
      " |        be passed to `tf.keras.Callbacks.on_train_batch_end`, such as\n",
      " |        `{'loss': 0.2, 'accuracy': 0.7}`.\n",
      " |  \n",
      " |  predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)\n",
      " |      Generates output predictions for the input samples.\n",
      " |      \n",
      " |      Computation is done in batches. This method is designed for performance in\n",
      " |      large-scale inputs. For a small number of inputs that fit in one batch,\n",
      " |      directly using `__call__` is recommended for faster execution, e.g.,\n",
      " |      `model(x)`, or `model(x, training=False)` if you have layers such as\n",
      " |      `tf.keras.layers.BatchNormalization` that behave differently during\n",
      " |      inference. Also, note the fact that test loss is not affected by\n",
      " |      regularization layers like noise and dropout.\n",
      " |      \n",
      " |      Args:\n",
      " |          x: Input samples. It could be:\n",
      " |            - A Numpy array (or array-like), or a list of arrays\n",
      " |              (in case the model has multiple inputs).\n",
      " |            - A TensorFlow tensor, or a list of tensors\n",
      " |              (in case the model has multiple inputs).\n",
      " |            - A `tf.data` dataset.\n",
      " |            - A generator or `keras.utils.Sequence` instance.\n",
      " |            A more detailed description of unpacking behavior for iterator types\n",
      " |            (Dataset, generator, Sequence) is given in the `Unpacking behavior\n",
      " |            for iterator-like inputs` section of `Model.fit`.\n",
      " |          batch_size: Integer or `None`.\n",
      " |              Number of samples per batch.\n",
      " |              If unspecified, `batch_size` will default to 32.\n",
      " |              Do not specify the `batch_size` if your data is in the\n",
      " |              form of dataset, generators, or `keras.utils.Sequence` instances\n",
      " |              (since they generate batches).\n",
      " |          verbose: Verbosity mode, 0 or 1.\n",
      " |          steps: Total number of steps (batches of samples)\n",
      " |              before declaring the prediction round finished.\n",
      " |              Ignored with the default value of `None`. If x is a `tf.data`\n",
      " |              dataset and `steps` is None, `predict` will\n",
      " |              run until the input dataset is exhausted.\n",
      " |          callbacks: List of `keras.callbacks.Callback` instances.\n",
      " |              List of callbacks to apply during prediction.\n",
      " |              See [callbacks](/api_docs/python/tf/keras/callbacks).\n",
      " |          max_queue_size: Integer. Used for generator or `keras.utils.Sequence`\n",
      " |              input only. Maximum size for the generator queue.\n",
      " |              If unspecified, `max_queue_size` will default to 10.\n",
      " |          workers: Integer. Used for generator or `keras.utils.Sequence` input\n",
      " |              only. Maximum number of processes to spin up when using\n",
      " |              process-based threading. If unspecified, `workers` will default\n",
      " |              to 1.\n",
      " |          use_multiprocessing: Boolean. Used for generator or\n",
      " |              `keras.utils.Sequence` input only. If `True`, use process-based\n",
      " |              threading. If unspecified, `use_multiprocessing` will default to\n",
      " |              `False`. Note that because this implementation relies on\n",
      " |              multiprocessing, you should not pass non-picklable arguments to\n",
      " |              the generator as they can't be passed easily to children processes.\n",
      " |      \n",
      " |      See the discussion of `Unpacking behavior for iterator-like inputs` for\n",
      " |      `Model.fit`. Note that Model.predict uses the same interpretation rules as\n",
      " |      `Model.fit` and `Model.evaluate`, so inputs must be unambiguous for all\n",
      " |      three methods.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Numpy array(s) of predictions.\n",
      " |      \n",
      " |      Raises:\n",
      " |          RuntimeError: If `model.predict` is wrapped in `tf.function`.\n",
      " |          ValueError: In case of mismatch between the provided\n",
      " |              input data and the model's expectations,\n",
      " |              or in case a stateful model receives a number of samples\n",
      " |              that is not a multiple of the batch size.\n",
      " |  \n",
      " |  predict_generator(self, generator, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)\n",
      " |      Generates predictions for the input samples from a data generator.\n",
      " |      \n",
      " |      DEPRECATED:\n",
      " |        `Model.predict` now supports generators, so there is no longer any need\n",
      " |        to use this endpoint.\n",
      " |  \n",
      " |  predict_on_batch(self, x)\n",
      " |      Returns predictions for a single batch of samples.\n",
      " |      \n",
      " |      Args:\n",
      " |          x: Input data. It could be:\n",
      " |            - A Numpy array (or array-like), or a list of arrays (in case the\n",
      " |                model has multiple inputs).\n",
      " |            - A TensorFlow tensor, or a list of tensors (in case the model has\n",
      " |                multiple inputs).\n",
      " |      \n",
      " |      Returns:\n",
      " |          Numpy array(s) of predictions.\n",
      " |      \n",
      " |      Raises:\n",
      " |          RuntimeError: If `model.predict_on_batch` is wrapped in `tf.function`.\n",
      " |          ValueError: In case of mismatch between given number of inputs and\n",
      " |            expectations of the model.\n",
      " |  \n",
      " |  predict_step(self, data)\n",
      " |      The logic for one inference step.\n",
      " |      \n",
      " |      This method can be overridden to support custom inference logic.\n",
      " |      This method is called by `Model.make_predict_function`.\n",
      " |      \n",
      " |      This method should contain the mathematical logic for one step of inference.\n",
      " |      This typically includes the forward pass.\n",
      " |      \n",
      " |      Configuration details for *how* this logic is run (e.g. `tf.function` and\n",
      " |      `tf.distribute.Strategy` settings), should be left to\n",
      " |      `Model.make_predict_function`, which can also be overridden.\n",
      " |      \n",
      " |      Args:\n",
      " |        data: A nested structure of `Tensor`s.\n",
      " |      \n",
      " |      Returns:\n",
      " |        The result of one inference step, typically the output of calling the\n",
      " |        `Model` on data.\n",
      " |  \n",
      " |  reset_metrics(self)\n",
      " |      Resets the state of all the metrics in the model.\n",
      " |      \n",
      " |      Examples:\n",
      " |      \n",
      " |      >>> inputs = tf.keras.layers.Input(shape=(3,))\n",
      " |      >>> outputs = tf.keras.layers.Dense(2)(inputs)\n",
      " |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)\n",
      " |      >>> model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\"])\n",
      " |      \n",
      " |      >>> x = np.random.random((2, 3))\n",
      " |      >>> y = np.random.randint(0, 2, (2, 2))\n",
      " |      >>> _ = model.fit(x, y, verbose=0)\n",
      " |      >>> assert all(float(m.result()) for m in model.metrics)\n",
      " |      \n",
      " |      >>> model.reset_metrics()\n",
      " |      >>> assert all(float(m.result()) == 0 for m in model.metrics)\n",
      " |  \n",
      " |  reset_states(self)\n",
      " |  \n",
      " |  save_spec(self, dynamic_batch=True)\n",
      " |      Returns the `tf.TensorSpec` of call inputs as a tuple `(args, kwargs)`.\n",
      " |      \n",
      " |      This value is automatically defined after calling the model for the first\n",
      " |      time. Afterwards, you can use it when exporting the model for serving:\n",
      " |      \n",
      " |      ```python\n",
      " |      model = tf.keras.Model(...)\n",
      " |      \n",
      " |      @tf.function\n",
      " |      def serve(*args, **kwargs):\n",
      " |        outputs = model(*args, **kwargs)\n",
      " |        # Apply postprocessing steps, or add additional outputs.\n",
      " |        ...\n",
      " |        return outputs\n",
      " |      \n",
      " |      # arg_specs is `[tf.TensorSpec(...), ...]`. kwarg_specs, in this example, is\n",
      " |      # an empty dict since functional models do not use keyword arguments.\n",
      " |      arg_specs, kwarg_specs = model.save_spec()\n",
      " |      \n",
      " |      model.save(path, signatures={\n",
      " |        'serving_default': serve.get_concrete_function(*arg_specs, **kwarg_specs)\n",
      " |      })\n",
      " |      ```\n",
      " |      \n",
      " |      Args:\n",
      " |        dynamic_batch: Whether to set the batch sizes of all the returned\n",
      " |          `tf.TensorSpec` to `None`. (Note that when defining functional or\n",
      " |          Sequential models with `tf.keras.Input([...], batch_size=X)`, the\n",
      " |          batch size will always be preserved). Defaults to `True`.\n",
      " |      Returns:\n",
      " |        If the model inputs are defined, returns a tuple `(args, kwargs)`. All\n",
      " |        elements in `args` and `kwargs` are `tf.TensorSpec`.\n",
      " |        If the model inputs are not defined, returns `None`.\n",
      " |        The model inputs are automatically set when calling the model,\n",
      " |        `model.fit`, `model.evaluate` or `model.predict`.\n",
      " |  \n",
      " |  save_weights(self, filepath, overwrite=True, save_format=None, options=None)\n",
      " |      Saves all layer weights.\n",
      " |      \n",
      " |      Either saves in HDF5 or in TensorFlow format based on the `save_format`\n",
      " |      argument.\n",
      " |      \n",
      " |      When saving in HDF5 format, the weight file has:\n",
      " |        - `layer_names` (attribute), a list of strings\n",
      " |            (ordered names of model layers).\n",
      " |        - For every layer, a `group` named `layer.name`\n",
      " |            - For every such layer group, a group attribute `weight_names`,\n",
      " |                a list of strings\n",
      " |                (ordered names of weights tensor of the layer).\n",
      " |            - For every weight in the layer, a dataset\n",
      " |                storing the weight value, named after the weight tensor.\n",
      " |      \n",
      " |      When saving in TensorFlow format, all objects referenced by the network are\n",
      " |      saved in the same format as `tf.train.Checkpoint`, including any `Layer`\n",
      " |      instances or `Optimizer` instances assigned to object attributes. For\n",
      " |      networks constructed from inputs and outputs using `tf.keras.Model(inputs,\n",
      " |      outputs)`, `Layer` instances used by the network are tracked/saved\n",
      " |      automatically. For user-defined classes which inherit from `tf.keras.Model`,\n",
      " |      `Layer` instances must be assigned to object attributes, typically in the\n",
      " |      constructor. See the documentation of `tf.train.Checkpoint` and\n",
      " |      `tf.keras.Model` for details.\n",
      " |      \n",
      " |      While the formats are the same, do not mix `save_weights` and\n",
      " |      `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be\n",
      " |      loaded using `Model.load_weights`. Checkpoints saved using\n",
      " |      `tf.train.Checkpoint.save` should be restored using the corresponding\n",
      " |      `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over\n",
      " |      `save_weights` for training checkpoints.\n",
      " |      \n",
      " |      The TensorFlow format matches objects and variables by starting at a root\n",
      " |      object, `self` for `save_weights`, and greedily matching attribute\n",
      " |      names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this\n",
      " |      is the `Checkpoint` even if the `Checkpoint` has a model attached. This\n",
      " |      means saving a `tf.keras.Model` using `save_weights` and loading into a\n",
      " |      `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match\n",
      " |      the `Model`'s variables. See the [guide to training\n",
      " |      checkpoints](https://www.tensorflow.org/guide/checkpoint) for details\n",
      " |      on the TensorFlow format.\n",
      " |      \n",
      " |      Args:\n",
      " |          filepath: String or PathLike, path to the file to save the weights to.\n",
      " |              When saving in TensorFlow format, this is the prefix used for\n",
      " |              checkpoint files (multiple files are generated). Note that the '.h5'\n",
      " |              suffix causes weights to be saved in HDF5 format.\n",
      " |          overwrite: Whether to silently overwrite any existing file at the\n",
      " |              target location, or provide the user with a manual prompt.\n",
      " |          save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or\n",
      " |              '.keras' will default to HDF5 if `save_format` is `None`. Otherwise\n",
      " |              `None` defaults to 'tf'.\n",
      " |          options: Optional `tf.train.CheckpointOptions` object that specifies\n",
      " |              options for saving weights.\n",
      " |      \n",
      " |      Raises:\n",
      " |          ImportError: If h5py is not available when attempting to save in HDF5\n",
      " |              format.\n",
      " |          ValueError: For invalid/unknown format arguments.\n",
      " |  \n",
      " |  test_on_batch(self, x, y=None, sample_weight=None, reset_metrics=True, return_dict=False)\n",
      " |      Test the model on a single batch of samples.\n",
      " |      \n",
      " |      Args:\n",
      " |          x: Input data. It could be:\n",
      " |            - A Numpy array (or array-like), or a list of arrays (in case the\n",
      " |                model has multiple inputs).\n",
      " |            - A TensorFlow tensor, or a list of tensors (in case the model has\n",
      " |                multiple inputs).\n",
      " |            - A dict mapping input names to the corresponding array/tensors, if\n",
      " |                the model has named inputs.\n",
      " |          y: Target data. Like the input data `x`, it could be either Numpy\n",
      " |            array(s) or TensorFlow tensor(s). It should be consistent with `x`\n",
      " |            (you cannot have Numpy inputs and tensor targets, or inversely).\n",
      " |          sample_weight: Optional array of the same length as x, containing\n",
      " |            weights to apply to the model's loss for each sample. In the case of\n",
      " |            temporal data, you can pass a 2D array with shape (samples,\n",
      " |            sequence_length), to apply a different weight to every timestep of\n",
      " |            every sample.\n",
      " |          reset_metrics: If `True`, the metrics returned will be only for this\n",
      " |            batch. If `False`, the metrics will be statefully accumulated across\n",
      " |            batches.\n",
      " |          return_dict: If `True`, loss and metric results are returned as a dict,\n",
      " |            with each key being the name of the metric. If `False`, they are\n",
      " |            returned as a list.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Scalar test loss (if the model has a single output and no metrics)\n",
      " |          or list of scalars (if the model has multiple outputs\n",
      " |          and/or metrics). The attribute `model.metrics_names` will give you\n",
      " |          the display labels for the scalar outputs.\n",
      " |      \n",
      " |      Raises:\n",
      " |          RuntimeError: If `model.test_on_batch` is wrapped in `tf.function`.\n",
      " |          ValueError: In case of invalid user-provided arguments.\n",
      " |  \n",
      " |  test_step(self, data)\n",
      " |      The logic for one evaluation step.\n",
      " |      \n",
      " |      This method can be overridden to support custom evaluation logic.\n",
      " |      This method is called by `Model.make_test_function`.\n",
      " |      \n",
      " |      This function should contain the mathematical logic for one step of\n",
      " |      evaluation.\n",
      " |      This typically includes the forward pass, loss calculation, and metrics\n",
      " |      updates.\n",
      " |      \n",
      " |      Configuration details for *how* this logic is run (e.g. `tf.function` and\n",
      " |      `tf.distribute.Strategy` settings), should be left to\n",
      " |      `Model.make_test_function`, which can also be overridden.\n",
      " |      \n",
      " |      Args:\n",
      " |        data: A nested structure of `Tensor`s.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A `dict` containing values that will be passed to\n",
      " |        `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically, the\n",
      " |        values of the `Model`'s metrics are returned.\n",
      " |  \n",
      " |  to_json(self, **kwargs)\n",
      " |      Returns a JSON string containing the network configuration.\n",
      " |      \n",
      " |      To load a network from a JSON save file, use\n",
      " |      `keras.models.model_from_json(json_string, custom_objects={})`.\n",
      " |      \n",
      " |      Args:\n",
      " |          **kwargs: Additional keyword arguments\n",
      " |              to be passed to `json.dumps()`.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A JSON string.\n",
      " |  \n",
      " |  to_yaml(self, **kwargs)\n",
      " |      Returns a yaml string containing the network configuration.\n",
      " |      \n",
      " |      Note: Since TF 2.6, this method is no longer supported and will raise a\n",
      " |      RuntimeError.\n",
      " |      \n",
      " |      To load a network from a yaml save file, use\n",
      " |      `keras.models.model_from_yaml(yaml_string, custom_objects={})`.\n",
      " |      \n",
      " |      `custom_objects` should be a dictionary mapping\n",
      " |      the names of custom losses / layers / etc to the corresponding\n",
      " |      functions / classes.\n",
      " |      \n",
      " |      Args:\n",
      " |          **kwargs: Additional keyword arguments\n",
      " |              to be passed to `yaml.dump()`.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A YAML string.\n",
      " |      \n",
      " |      Raises:\n",
      " |          RuntimeError: announces that the method poses a security risk\n",
      " |            (Use the safer `safe_load` function instead of `unsafe_load` when possible)\n",
      " |  \n",
      " |  train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset_metrics=True, return_dict=False)\n",
      " |      Runs a single gradient update on a single batch of data.\n",
      " |      \n",
      " |      Args:\n",
      " |          x: Input data. It could be:\n",
      " |            - A Numpy array (or array-like), or a list of arrays\n",
      " |                (in case the model has multiple inputs).\n",
      " |            - A TensorFlow tensor, or a list of tensors\n",
      " |                (in case the model has multiple inputs).\n",
      " |            - A dict mapping input names to the corresponding array/tensors,\n",
      " |                if the model has named inputs.\n",
      " |          y: Target data. Like the input data `x`, it could be either Numpy\n",
      " |            array(s) or TensorFlow tensor(s). It should be consistent with `x`\n",
      " |            (you cannot have Numpy inputs and tensor targets, or inversely).\n",
      " |          sample_weight: Optional array of the same length as x, containing\n",
      " |            weights to apply to the model's loss for each sample. In the case of\n",
      " |            temporal data, you can pass a 2D array with shape (samples,\n",
      " |            sequence_length), to apply a different weight to every timestep of\n",
      " |            every sample.\n",
      " |          class_weight: Optional dictionary mapping class indices (integers) to a\n",
      " |            weight (float) to apply to the model's loss for the samples from this\n",
      " |            class during training. This can be useful to tell the model to \"pay\n",
      " |            more attention\" to samples from an under-represented class.\n",
      " |          reset_metrics: If `True`, the metrics returned will be only for this\n",
      " |            batch. If `False`, the metrics will be statefully accumulated across\n",
      " |            batches.\n",
      " |          return_dict: If `True`, loss and metric results are returned as a dict,\n",
      " |            with each key being the name of the metric. If `False`, they are\n",
      " |            returned as a list.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Scalar training loss\n",
      " |          (if the model has a single output and no metrics)\n",
      " |          or list of scalars (if the model has multiple outputs\n",
      " |          and/or metrics). The attribute `model.metrics_names` will give you\n",
      " |          the display labels for the scalar outputs.\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If `model.train_on_batch` is wrapped in `tf.function`.\n",
      " |        ValueError: In case of invalid user-provided arguments.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Class methods inherited from keras.engine.training.Model:\n",
      " |  \n",
      " |  from_config(config, custom_objects=None) from builtins.type\n",
      " |      Creates a layer from its config.\n",
      " |      \n",
      " |      This method is the reverse of `get_config`,\n",
      " |      capable of instantiating the same layer from the config\n",
      " |      dictionary. It does not handle layer connectivity\n",
      " |      (handled by Network), nor weights (handled by `set_weights`).\n",
      " |      \n",
      " |      Args:\n",
      " |          config: A Python dictionary, typically the\n",
      " |              output of get_config.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A layer instance.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Static methods inherited from keras.engine.training.Model:\n",
      " |  \n",
      " |  __new__(cls, *args, **kwargs)\n",
      " |      Create and return a new object.  See help(type) for accurate signature.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors inherited from keras.engine.training.Model:\n",
      " |  \n",
      " |  distribute_strategy\n",
      " |      The `tf.distribute.Strategy` this model was created under.\n",
      " |  \n",
      " |  layers\n",
      " |  \n",
      " |  metrics\n",
      " |      Returns the model's metrics added using `compile`, `add_metric` APIs.\n",
      " |      \n",
      " |      Note: Metrics passed to `compile()` are available only after a `keras.Model`\n",
      " |      has been trained/evaluated on actual data.\n",
      " |      \n",
      " |      Examples:\n",
      " |      \n",
      " |      >>> inputs = tf.keras.layers.Input(shape=(3,))\n",
      " |      >>> outputs = tf.keras.layers.Dense(2)(inputs)\n",
      " |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)\n",
      " |      >>> model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\"])\n",
      " |      >>> [m.name for m in model.metrics]\n",
      " |      []\n",
      " |      \n",
      " |      >>> x = np.random.random((2, 3))\n",
      " |      >>> y = np.random.randint(0, 2, (2, 2))\n",
      " |      >>> model.fit(x, y)\n",
      " |      >>> [m.name for m in model.metrics]\n",
      " |      ['loss', 'mae']\n",
      " |      \n",
      " |      >>> inputs = tf.keras.layers.Input(shape=(3,))\n",
      " |      >>> d = tf.keras.layers.Dense(2, name='out')\n",
      " |      >>> output_1 = d(inputs)\n",
      " |      >>> output_2 = d(inputs)\n",
      " |      >>> model = tf.keras.models.Model(\n",
      " |      ...    inputs=inputs, outputs=[output_1, output_2])\n",
      " |      >>> model.add_metric(\n",
      " |      ...    tf.reduce_sum(output_2), name='mean', aggregation='mean')\n",
      " |      >>> model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\", \"acc\"])\n",
      " |      >>> model.fit(x, (y, y))\n",
      " |      >>> [m.name for m in model.metrics]\n",
      " |      ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',\n",
      " |      'out_1_acc', 'mean']\n",
      " |  \n",
      " |  metrics_names\n",
      " |      Returns the model's display labels for all outputs.\n",
      " |      \n",
      " |      Note: `metrics_names` are available only after a `keras.Model` has been\n",
      " |      trained/evaluated on actual data.\n",
      " |      \n",
      " |      Examples:\n",
      " |      \n",
      " |      >>> inputs = tf.keras.layers.Input(shape=(3,))\n",
      " |      >>> outputs = tf.keras.layers.Dense(2)(inputs)\n",
      " |      >>> model = tf.keras.models.Model(inputs=inputs, outputs=outputs)\n",
      " |      >>> model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\"])\n",
      " |      >>> model.metrics_names\n",
      " |      []\n",
      " |      \n",
      " |      >>> x = np.random.random((2, 3))\n",
      " |      >>> y = np.random.randint(0, 2, (2, 2))\n",
      " |      >>> model.fit(x, y)\n",
      " |      >>> model.metrics_names\n",
      " |      ['loss', 'mae']\n",
      " |      \n",
      " |      >>> inputs = tf.keras.layers.Input(shape=(3,))\n",
      " |      >>> d = tf.keras.layers.Dense(2, name='out')\n",
      " |      >>> output_1 = d(inputs)\n",
      " |      >>> output_2 = d(inputs)\n",
      " |      >>> model = tf.keras.models.Model(\n",
      " |      ...    inputs=inputs, outputs=[output_1, output_2])\n",
      " |      >>> model.compile(optimizer=\"Adam\", loss=\"mse\", metrics=[\"mae\", \"acc\"])\n",
      " |      >>> model.fit(x, (y, y))\n",
      " |      >>> model.metrics_names\n",
      " |      ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae',\n",
      " |      'out_1_acc']\n",
      " |  \n",
      " |  non_trainable_weights\n",
      " |      List of all non-trainable weights tracked by this layer.\n",
      " |      \n",
      " |      Non-trainable weights are *not* updated during training. They are expected\n",
      " |      to be updated manually in `call()`.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A list of non-trainable variables.\n",
      " |  \n",
      " |  run_eagerly\n",
      " |      Settable attribute indicating whether the model should run eagerly.\n",
      " |      \n",
      " |      Running eagerly means that your model will be run step by step,\n",
      " |      like Python code. Your model might run slower, but it should become easier\n",
      " |      for you to debug it by stepping into individual layer calls.\n",
      " |      \n",
      " |      By default, we will attempt to compile your model to a static graph to\n",
      " |      deliver the best execution performance.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Boolean, whether the model should run eagerly.\n",
      " |  \n",
      " |  state_updates\n",
      " |      Deprecated, do NOT use!\n",
      " |      \n",
      " |      Returns the `updates` from all layers that are stateful.\n",
      " |      \n",
      " |      This is useful for separating training updates and\n",
      " |      state updates, e.g. when we need to update a layer's internal state\n",
      " |      during prediction.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A list of update ops.\n",
      " |  \n",
      " |  trainable_weights\n",
      " |      List of all trainable weights tracked by this layer.\n",
      " |      \n",
      " |      Trainable weights are updated via gradient descent during training.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A list of trainable variables.\n",
      " |  \n",
      " |  weights\n",
      " |      Returns the list of all layer variables/weights.\n",
      " |      \n",
      " |      Note: This will not track the weights of nested `tf.Modules` that are not\n",
      " |      themselves Keras layers.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A list of variables.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Methods inherited from keras.engine.base_layer.Layer:\n",
      " |  \n",
      " |  __call__(self, *args, **kwargs)\n",
      " |      Wraps `call`, applying pre- and post-processing steps.\n",
      " |      \n",
      " |      Args:\n",
      " |        *args: Positional arguments to be passed to `self.call`.\n",
      " |        **kwargs: Keyword arguments to be passed to `self.call`.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Output tensor(s).\n",
      " |      \n",
      " |      Note:\n",
      " |        - The following optional keyword arguments are reserved for specific uses:\n",
      " |          * `training`: Boolean scalar tensor of Python boolean indicating\n",
      " |            whether the `call` is meant for training or inference.\n",
      " |          * `mask`: Boolean input mask.\n",
      " |        - If the layer's `call` method takes a `mask` argument (as some Keras\n",
      " |          layers do), its default value will be set to the mask generated\n",
      " |          for `inputs` by the previous layer (if `input` did come from\n",
      " |          a layer that generated a corresponding mask, i.e. if it came from\n",
      " |          a Keras layer with masking support.\n",
      " |        - If the layer is not built, the method will call `build`.\n",
      " |      \n",
      " |      Raises:\n",
      " |        ValueError: if the layer's `call` method returns None (an invalid value).\n",
      " |        RuntimeError: if `super().__init__()` was not called in the constructor.\n",
      " |  \n",
      " |  __delattr__(self, name)\n",
      " |      Implement delattr(self, name).\n",
      " |  \n",
      " |  __getstate__(self)\n",
      " |  \n",
      " |  __setstate__(self, state)\n",
      " |  \n",
      " |  add_loss(self, losses, **kwargs)\n",
      " |      Add loss tensor(s), potentially dependent on layer inputs.\n",
      " |      \n",
      " |      Some losses (for instance, activity regularization losses) may be dependent\n",
      " |      on the inputs passed when calling a layer. Hence, when reusing the same\n",
      " |      layer on different inputs `a` and `b`, some entries in `layer.losses` may\n",
      " |      be dependent on `a` and some on `b`. This method automatically keeps track\n",
      " |      of dependencies.\n",
      " |      \n",
      " |      This method can be used inside a subclassed layer or model's `call`\n",
      " |      function, in which case `losses` should be a Tensor or list of Tensors.\n",
      " |      \n",
      " |      Example:\n",
      " |      \n",
      " |      ```python\n",
      " |      class MyLayer(tf.keras.layers.Layer):\n",
      " |        def call(self, inputs):\n",
      " |          self.add_loss(tf.abs(tf.reduce_mean(inputs)))\n",
      " |          return inputs\n",
      " |      ```\n",
      " |      \n",
      " |      This method can also be called directly on a Functional Model during\n",
      " |      construction. In this case, any loss Tensors passed to this Model must\n",
      " |      be symbolic and be able to be traced back to the model's `Input`s. These\n",
      " |      losses become part of the model's topology and are tracked in `get_config`.\n",
      " |      \n",
      " |      Example:\n",
      " |      \n",
      " |      ```python\n",
      " |      inputs = tf.keras.Input(shape=(10,))\n",
      " |      x = tf.keras.layers.Dense(10)(inputs)\n",
      " |      outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      model = tf.keras.Model(inputs, outputs)\n",
      " |      # Activity regularization.\n",
      " |      model.add_loss(tf.abs(tf.reduce_mean(x)))\n",
      " |      ```\n",
      " |      \n",
      " |      If this is not the case for your loss (if, for example, your loss references\n",
      " |      a `Variable` of one of the model's layers), you can wrap your loss in a\n",
      " |      zero-argument lambda. These losses are not tracked as part of the model's\n",
      " |      topology since they can't be serialized.\n",
      " |      \n",
      " |      Example:\n",
      " |      \n",
      " |      ```python\n",
      " |      inputs = tf.keras.Input(shape=(10,))\n",
      " |      d = tf.keras.layers.Dense(10)\n",
      " |      x = d(inputs)\n",
      " |      outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      model = tf.keras.Model(inputs, outputs)\n",
      " |      # Weight regularization.\n",
      " |      model.add_loss(lambda: tf.reduce_mean(d.kernel))\n",
      " |      ```\n",
      " |      \n",
      " |      Args:\n",
      " |        losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses\n",
      " |          may also be zero-argument callables which create a loss tensor.\n",
      " |        **kwargs: Additional keyword arguments for backward compatibility.\n",
      " |          Accepted values:\n",
      " |            inputs - Deprecated, will be automatically inferred.\n",
      " |  \n",
      " |  add_metric(self, value, name=None, **kwargs)\n",
      " |      Adds metric tensor to the layer.\n",
      " |      \n",
      " |      This method can be used inside the `call()` method of a subclassed layer\n",
      " |      or model.\n",
      " |      \n",
      " |      ```python\n",
      " |      class MyMetricLayer(tf.keras.layers.Layer):\n",
      " |        def __init__(self):\n",
      " |          super(MyMetricLayer, self).__init__(name='my_metric_layer')\n",
      " |          self.mean = tf.keras.metrics.Mean(name='metric_1')\n",
      " |      \n",
      " |        def call(self, inputs):\n",
      " |          self.add_metric(self.mean(inputs))\n",
      " |          self.add_metric(tf.reduce_sum(inputs), name='metric_2')\n",
      " |          return inputs\n",
      " |      ```\n",
      " |      \n",
      " |      This method can also be called directly on a Functional Model during\n",
      " |      construction. In this case, any tensor passed to this Model must\n",
      " |      be symbolic and be able to be traced back to the model's `Input`s. These\n",
      " |      metrics become part of the model's topology and are tracked when you\n",
      " |      save the model via `save()`.\n",
      " |      \n",
      " |      ```python\n",
      " |      inputs = tf.keras.Input(shape=(10,))\n",
      " |      x = tf.keras.layers.Dense(10)(inputs)\n",
      " |      outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      model = tf.keras.Model(inputs, outputs)\n",
      " |      model.add_metric(math_ops.reduce_sum(x), name='metric_1')\n",
      " |      ```\n",
      " |      \n",
      " |      Note: Calling `add_metric()` with the result of a metric object on a\n",
      " |      Functional Model, as shown in the example below, is not supported. This is\n",
      " |      because we cannot trace the metric result tensor back to the model's inputs.\n",
      " |      \n",
      " |      ```python\n",
      " |      inputs = tf.keras.Input(shape=(10,))\n",
      " |      x = tf.keras.layers.Dense(10)(inputs)\n",
      " |      outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      model = tf.keras.Model(inputs, outputs)\n",
      " |      model.add_metric(tf.keras.metrics.Mean()(x), name='metric_1')\n",
      " |      ```\n",
      " |      \n",
      " |      Args:\n",
      " |        value: Metric tensor.\n",
      " |        name: String metric name.\n",
      " |        **kwargs: Additional keyword arguments for backward compatibility.\n",
      " |          Accepted values:\n",
      " |          `aggregation` - When the `value` tensor provided is not the result of\n",
      " |          calling a `keras.Metric` instance, it will be aggregated by default\n",
      " |          using a `keras.Metric.Mean`.\n",
      " |  \n",
      " |  add_update(self, updates, inputs=None)\n",
      " |      Add update op(s), potentially dependent on layer inputs.\n",
      " |      \n",
      " |      Weight updates (for instance, the updates of the moving mean and variance\n",
      " |      in a BatchNormalization layer) may be dependent on the inputs passed\n",
      " |      when calling a layer. Hence, when reusing the same layer on\n",
      " |      different inputs `a` and `b`, some entries in `layer.updates` may be\n",
      " |      dependent on `a` and some on `b`. This method automatically keeps track\n",
      " |      of dependencies.\n",
      " |      \n",
      " |      This call is ignored when eager execution is enabled (in that case, variable\n",
      " |      updates are run on the fly and thus do not need to be tracked for later\n",
      " |      execution).\n",
      " |      \n",
      " |      Args:\n",
      " |        updates: Update op, or list/tuple of update ops, or zero-arg callable\n",
      " |          that returns an update op. A zero-arg callable should be passed in\n",
      " |          order to disable running the updates by setting `trainable=False`\n",
      " |          on this Layer, when executing in Eager mode.\n",
      " |        inputs: Deprecated, will be automatically inferred.\n",
      " |  \n",
      " |  add_variable(self, *args, **kwargs)\n",
      " |      Deprecated, do NOT use! Alias for `add_weight`.\n",
      " |  \n",
      " |  add_weight(self, name=None, shape=None, dtype=None, initializer=None, regularizer=None, trainable=None, constraint=None, use_resource=None, synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>, **kwargs)\n",
      " |      Adds a new variable to the layer.\n",
      " |      \n",
      " |      Args:\n",
      " |        name: Variable name.\n",
      " |        shape: Variable shape. Defaults to scalar if unspecified.\n",
      " |        dtype: The type of the variable. Defaults to `self.dtype`.\n",
      " |        initializer: Initializer instance (callable).\n",
      " |        regularizer: Regularizer instance (callable).\n",
      " |        trainable: Boolean, whether the variable should be part of the layer's\n",
      " |          \"trainable_variables\" (e.g. variables, biases)\n",
      " |          or \"non_trainable_variables\" (e.g. BatchNorm mean and variance).\n",
      " |          Note that `trainable` cannot be `True` if `synchronization`\n",
      " |          is set to `ON_READ`.\n",
      " |        constraint: Constraint instance (callable).\n",
      " |        use_resource: Whether to use `ResourceVariable`.\n",
      " |        synchronization: Indicates when a distributed a variable will be\n",
      " |          aggregated. Accepted values are constants defined in the class\n",
      " |          `tf.VariableSynchronization`. By default the synchronization is set to\n",
      " |          `AUTO` and the current `DistributionStrategy` chooses\n",
      " |          when to synchronize. If `synchronization` is set to `ON_READ`,\n",
      " |          `trainable` must not be set to `True`.\n",
      " |        aggregation: Indicates how a distributed variable will be aggregated.\n",
      " |          Accepted values are constants defined in the class\n",
      " |          `tf.VariableAggregation`.\n",
      " |        **kwargs: Additional keyword arguments. Accepted values are `getter`,\n",
      " |          `collections`, `experimental_autocast` and `caching_device`.\n",
      " |      \n",
      " |      Returns:\n",
      " |        The variable created.\n",
      " |      \n",
      " |      Raises:\n",
      " |        ValueError: When giving unsupported dtype and no initializer or when\n",
      " |          trainable has been set to True with synchronization set as `ON_READ`.\n",
      " |  \n",
      " |  apply(self, inputs, *args, **kwargs)\n",
      " |      Deprecated, do NOT use!\n",
      " |      \n",
      " |      This is an alias of `self.__call__`.\n",
      " |      \n",
      " |      Args:\n",
      " |        inputs: Input tensor(s).\n",
      " |        *args: additional positional arguments to be passed to `self.call`.\n",
      " |        **kwargs: additional keyword arguments to be passed to `self.call`.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Output tensor(s).\n",
      " |  \n",
      " |  compute_mask(self, inputs, mask=None)\n",
      " |      Computes an output mask tensor.\n",
      " |      \n",
      " |      Args:\n",
      " |          inputs: Tensor or list of tensors.\n",
      " |          mask: Tensor or list of tensors.\n",
      " |      \n",
      " |      Returns:\n",
      " |          None or a tensor (or list of tensors,\n",
      " |              one per output tensor of the layer).\n",
      " |  \n",
      " |  compute_output_shape(self, input_shape)\n",
      " |      Computes the output shape of the layer.\n",
      " |      \n",
      " |      If the layer has not been built, this method will call `build` on the\n",
      " |      layer. This assumes that the layer will later be used with inputs that\n",
      " |      match the input shape provided here.\n",
      " |      \n",
      " |      Args:\n",
      " |          input_shape: Shape tuple (tuple of integers)\n",
      " |              or list of shape tuples (one per output tensor of the layer).\n",
      " |              Shape tuples can include None for free dimensions,\n",
      " |              instead of an integer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          An input shape tuple.\n",
      " |  \n",
      " |  compute_output_signature(self, input_signature)\n",
      " |      Compute the output tensor signature of the layer based on the inputs.\n",
      " |      \n",
      " |      Unlike a TensorShape object, a TensorSpec object contains both shape\n",
      " |      and dtype information for a tensor. This method allows layers to provide\n",
      " |      output dtype information if it is different from the input dtype.\n",
      " |      For any layer that doesn't implement this function,\n",
      " |      the framework will fall back to use `compute_output_shape`, and will\n",
      " |      assume that the output dtype matches the input dtype.\n",
      " |      \n",
      " |      Args:\n",
      " |        input_signature: Single TensorSpec or nested structure of TensorSpec\n",
      " |          objects, describing a candidate input for the layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Single TensorSpec or nested structure of TensorSpec objects, describing\n",
      " |          how the layer would transform the provided input.\n",
      " |      \n",
      " |      Raises:\n",
      " |        TypeError: If input_signature contains a non-TensorSpec object.\n",
      " |  \n",
      " |  count_params(self)\n",
      " |      Count the total number of scalars composing the weights.\n",
      " |      \n",
      " |      Returns:\n",
      " |          An integer count.\n",
      " |      \n",
      " |      Raises:\n",
      " |          ValueError: if the layer isn't yet built\n",
      " |            (in which case its weights aren't yet defined).\n",
      " |  \n",
      " |  finalize_state(self)\n",
      " |      Finalizes the layers state after updating layer weights.\n",
      " |      \n",
      " |      This function can be subclassed in a layer and will be called after updating\n",
      " |      a layer weights. It can be overridden to finalize any additional layer state\n",
      " |      after a weight update.\n",
      " |  \n",
      " |  get_input_at(self, node_index)\n",
      " |      Retrieves the input tensor(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first input node of the layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A tensor (or list of tensors if the layer has multiple inputs).\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If called in Eager mode.\n",
      " |  \n",
      " |  get_input_mask_at(self, node_index)\n",
      " |      Retrieves the input mask tensor(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first time the layer was called.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A mask tensor\n",
      " |          (or list of tensors if the layer has multiple inputs).\n",
      " |  \n",
      " |  get_input_shape_at(self, node_index)\n",
      " |      Retrieves the input shape(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first time the layer was called.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A shape tuple\n",
      " |          (or list of shape tuples if the layer has multiple inputs).\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If called in Eager mode.\n",
      " |  \n",
      " |  get_losses_for(self, inputs)\n",
      " |      Deprecated, do NOT use!\n",
      " |      \n",
      " |      Retrieves losses relevant to a specific set of inputs.\n",
      " |      \n",
      " |      Args:\n",
      " |        inputs: Input tensor or list/tuple of input tensors.\n",
      " |      \n",
      " |      Returns:\n",
      " |        List of loss tensors of the layer that depend on `inputs`.\n",
      " |  \n",
      " |  get_output_at(self, node_index)\n",
      " |      Retrieves the output tensor(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first output node of the layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A tensor (or list of tensors if the layer has multiple outputs).\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If called in Eager mode.\n",
      " |  \n",
      " |  get_output_mask_at(self, node_index)\n",
      " |      Retrieves the output mask tensor(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first time the layer was called.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A mask tensor\n",
      " |          (or list of tensors if the layer has multiple outputs).\n",
      " |  \n",
      " |  get_output_shape_at(self, node_index)\n",
      " |      Retrieves the output shape(s) of a layer at a given node.\n",
      " |      \n",
      " |      Args:\n",
      " |          node_index: Integer, index of the node\n",
      " |              from which to retrieve the attribute.\n",
      " |              E.g. `node_index=0` will correspond to the\n",
      " |              first time the layer was called.\n",
      " |      \n",
      " |      Returns:\n",
      " |          A shape tuple\n",
      " |          (or list of shape tuples if the layer has multiple outputs).\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If called in Eager mode.\n",
      " |  \n",
      " |  get_updates_for(self, inputs)\n",
      " |      Deprecated, do NOT use!\n",
      " |      \n",
      " |      Retrieves updates relevant to a specific set of inputs.\n",
      " |      \n",
      " |      Args:\n",
      " |        inputs: Input tensor or list/tuple of input tensors.\n",
      " |      \n",
      " |      Returns:\n",
      " |        List of update ops of the layer that depend on `inputs`.\n",
      " |  \n",
      " |  set_weights(self, weights)\n",
      " |      Sets the weights of the layer, from NumPy arrays.\n",
      " |      \n",
      " |      The weights of a layer represent the state of the layer. This function\n",
      " |      sets the weight values from numpy arrays. The weight values should be\n",
      " |      passed in the order they are created by the layer. Note that the layer's\n",
      " |      weights must be instantiated before calling this function, by calling\n",
      " |      the layer.\n",
      " |      \n",
      " |      For example, a `Dense` layer returns a list of two values: the kernel matrix\n",
      " |      and the bias vector. These can be used to set the weights of another\n",
      " |      `Dense` layer:\n",
      " |      \n",
      " |      >>> layer_a = tf.keras.layers.Dense(1,\n",
      " |      ...   kernel_initializer=tf.constant_initializer(1.))\n",
      " |      >>> a_out = layer_a(tf.convert_to_tensor([[1., 2., 3.]]))\n",
      " |      >>> layer_a.get_weights()\n",
      " |      [array([[1.],\n",
      " |             [1.],\n",
      " |             [1.]], dtype=float32), array([0.], dtype=float32)]\n",
      " |      >>> layer_b = tf.keras.layers.Dense(1,\n",
      " |      ...   kernel_initializer=tf.constant_initializer(2.))\n",
      " |      >>> b_out = layer_b(tf.convert_to_tensor([[10., 20., 30.]]))\n",
      " |      >>> layer_b.get_weights()\n",
      " |      [array([[2.],\n",
      " |             [2.],\n",
      " |             [2.]], dtype=float32), array([0.], dtype=float32)]\n",
      " |      >>> layer_b.set_weights(layer_a.get_weights())\n",
      " |      >>> layer_b.get_weights()\n",
      " |      [array([[1.],\n",
      " |             [1.],\n",
      " |             [1.]], dtype=float32), array([0.], dtype=float32)]\n",
      " |      \n",
      " |      Args:\n",
      " |        weights: a list of NumPy arrays. The number\n",
      " |          of arrays and their shape must match\n",
      " |          number of the dimensions of the weights\n",
      " |          of the layer (i.e. it should match the\n",
      " |          output of `get_weights`).\n",
      " |      \n",
      " |      Raises:\n",
      " |        ValueError: If the provided weights list does not match the\n",
      " |          layer's specifications.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors inherited from keras.engine.base_layer.Layer:\n",
      " |  \n",
      " |  activity_regularizer\n",
      " |      Optional regularizer function for the output of this layer.\n",
      " |  \n",
      " |  compute_dtype\n",
      " |      The dtype of the layer's computations.\n",
      " |      \n",
      " |      This is equivalent to `Layer.dtype_policy.compute_dtype`. Unless\n",
      " |      mixed precision is used, this is the same as `Layer.dtype`, the dtype of\n",
      " |      the weights.\n",
      " |      \n",
      " |      Layers automatically cast their inputs to the compute dtype, which causes\n",
      " |      computations and the output to be in the compute dtype as well. This is done\n",
      " |      by the base Layer class in `Layer.__call__`, so you do not have to insert\n",
      " |      these casts if implementing your own layer.\n",
      " |      \n",
      " |      Layers often perform certain internal computations in higher precision when\n",
      " |      `compute_dtype` is float16 or bfloat16 for numeric stability. The output\n",
      " |      will still typically be float16 or bfloat16 in such cases.\n",
      " |      \n",
      " |      Returns:\n",
      " |        The layer's compute dtype.\n",
      " |  \n",
      " |  dtype\n",
      " |      The dtype of the layer weights.\n",
      " |      \n",
      " |      This is equivalent to `Layer.dtype_policy.variable_dtype`. Unless\n",
      " |      mixed precision is used, this is the same as `Layer.compute_dtype`, the\n",
      " |      dtype of the layer's computations.\n",
      " |  \n",
      " |  dtype_policy\n",
      " |      The dtype policy associated with this layer.\n",
      " |      \n",
      " |      This is an instance of a `tf.keras.mixed_precision.Policy`.\n",
      " |  \n",
      " |  dynamic\n",
      " |      Whether the layer is dynamic (eager-only); set in the constructor.\n",
      " |  \n",
      " |  inbound_nodes\n",
      " |      Deprecated, do NOT use! Only for compatibility with external Keras.\n",
      " |  \n",
      " |  input\n",
      " |      Retrieves the input tensor(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has exactly one input,\n",
      " |      i.e. if it is connected to one incoming layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Input tensor or list of input tensors.\n",
      " |      \n",
      " |      Raises:\n",
      " |        RuntimeError: If called in Eager mode.\n",
      " |        AttributeError: If no inbound nodes are found.\n",
      " |  \n",
      " |  input_mask\n",
      " |      Retrieves the input mask tensor(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has exactly one inbound node,\n",
      " |      i.e. if it is connected to one incoming layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Input mask tensor (potentially None) or list of input\n",
      " |          mask tensors.\n",
      " |      \n",
      " |      Raises:\n",
      " |          AttributeError: if the layer is connected to\n",
      " |          more than one incoming layers.\n",
      " |  \n",
      " |  input_shape\n",
      " |      Retrieves the input shape(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has exactly one input,\n",
      " |      i.e. if it is connected to one incoming layer, or if all inputs\n",
      " |      have the same shape.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Input shape, as an integer shape tuple\n",
      " |          (or list of shape tuples, one tuple per input tensor).\n",
      " |      \n",
      " |      Raises:\n",
      " |          AttributeError: if the layer has no defined input_shape.\n",
      " |          RuntimeError: if called in Eager mode.\n",
      " |  \n",
      " |  input_spec\n",
      " |      `InputSpec` instance(s) describing the input format for this layer.\n",
      " |      \n",
      " |      When you create a layer subclass, you can set `self.input_spec` to enable\n",
      " |      the layer to run input compatibility checks when it is called.\n",
      " |      Consider a `Conv2D` layer: it can only be called on a single input tensor\n",
      " |      of rank 4. As such, you can set, in `__init__()`:\n",
      " |      \n",
      " |      ```python\n",
      " |      self.input_spec = tf.keras.layers.InputSpec(ndim=4)\n",
      " |      ```\n",
      " |      \n",
      " |      Now, if you try to call the layer on an input that isn't rank 4\n",
      " |      (for instance, an input of shape `(2,)`, it will raise a nicely-formatted\n",
      " |      error:\n",
      " |      \n",
      " |      ```\n",
      " |      ValueError: Input 0 of layer conv2d is incompatible with the layer:\n",
      " |      expected ndim=4, found ndim=1. Full shape received: [2]\n",
      " |      ```\n",
      " |      \n",
      " |      Input checks that can be specified via `input_spec` include:\n",
      " |      - Structure (e.g. a single input, a list of 2 inputs, etc)\n",
      " |      - Shape\n",
      " |      - Rank (ndim)\n",
      " |      - Dtype\n",
      " |      \n",
      " |      For more information, see `tf.keras.layers.InputSpec`.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A `tf.keras.layers.InputSpec` instance, or nested structure thereof.\n",
      " |  \n",
      " |  losses\n",
      " |      List of losses added using the `add_loss()` API.\n",
      " |      \n",
      " |      Variable regularization tensors are created when this property is accessed,\n",
      " |      so it is eager safe: accessing `losses` under a `tf.GradientTape` will\n",
      " |      propagate gradients back to the corresponding variables.\n",
      " |      \n",
      " |      Examples:\n",
      " |      \n",
      " |      >>> class MyLayer(tf.keras.layers.Layer):\n",
      " |      ...   def call(self, inputs):\n",
      " |      ...     self.add_loss(tf.abs(tf.reduce_mean(inputs)))\n",
      " |      ...     return inputs\n",
      " |      >>> l = MyLayer()\n",
      " |      >>> l(np.ones((10, 1)))\n",
      " |      >>> l.losses\n",
      " |      [1.0]\n",
      " |      \n",
      " |      >>> inputs = tf.keras.Input(shape=(10,))\n",
      " |      >>> x = tf.keras.layers.Dense(10)(inputs)\n",
      " |      >>> outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      >>> model = tf.keras.Model(inputs, outputs)\n",
      " |      >>> # Activity regularization.\n",
      " |      >>> len(model.losses)\n",
      " |      0\n",
      " |      >>> model.add_loss(tf.abs(tf.reduce_mean(x)))\n",
      " |      >>> len(model.losses)\n",
      " |      1\n",
      " |      \n",
      " |      >>> inputs = tf.keras.Input(shape=(10,))\n",
      " |      >>> d = tf.keras.layers.Dense(10, kernel_initializer='ones')\n",
      " |      >>> x = d(inputs)\n",
      " |      >>> outputs = tf.keras.layers.Dense(1)(x)\n",
      " |      >>> model = tf.keras.Model(inputs, outputs)\n",
      " |      >>> # Weight regularization.\n",
      " |      >>> model.add_loss(lambda: tf.reduce_mean(d.kernel))\n",
      " |      >>> model.losses\n",
      " |      [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>]\n",
      " |      \n",
      " |      Returns:\n",
      " |        A list of tensors.\n",
      " |  \n",
      " |  name\n",
      " |      Name of the layer (string), set in the constructor.\n",
      " |  \n",
      " |  non_trainable_variables\n",
      " |      Sequence of non-trainable variables owned by this module and its submodules.\n",
      " |      \n",
      " |      Note: this method uses reflection to find variables on the current instance\n",
      " |      and submodules. For performance reasons you may wish to cache the result\n",
      " |      of calling this method if you don't expect the return value to change.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A sequence of variables for the current module (sorted by attribute\n",
      " |        name) followed by variables from all submodules recursively (breadth\n",
      " |        first).\n",
      " |  \n",
      " |  outbound_nodes\n",
      " |      Deprecated, do NOT use! Only for compatibility with external Keras.\n",
      " |  \n",
      " |  output\n",
      " |      Retrieves the output tensor(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has exactly one output,\n",
      " |      i.e. if it is connected to one incoming layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |        Output tensor or list of output tensors.\n",
      " |      \n",
      " |      Raises:\n",
      " |        AttributeError: if the layer is connected to more than one incoming\n",
      " |          layers.\n",
      " |        RuntimeError: if called in Eager mode.\n",
      " |  \n",
      " |  output_mask\n",
      " |      Retrieves the output mask tensor(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has exactly one inbound node,\n",
      " |      i.e. if it is connected to one incoming layer.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Output mask tensor (potentially None) or list of output\n",
      " |          mask tensors.\n",
      " |      \n",
      " |      Raises:\n",
      " |          AttributeError: if the layer is connected to\n",
      " |          more than one incoming layers.\n",
      " |  \n",
      " |  output_shape\n",
      " |      Retrieves the output shape(s) of a layer.\n",
      " |      \n",
      " |      Only applicable if the layer has one output,\n",
      " |      or if all outputs have the same shape.\n",
      " |      \n",
      " |      Returns:\n",
      " |          Output shape, as an integer shape tuple\n",
      " |          (or list of shape tuples, one tuple per output tensor).\n",
      " |      \n",
      " |      Raises:\n",
      " |          AttributeError: if the layer has no defined output shape.\n",
      " |          RuntimeError: if called in Eager mode.\n",
      " |  \n",
      " |  stateful\n",
      " |  \n",
      " |  supports_masking\n",
      " |      Whether this layer supports computing a mask using `compute_mask`.\n",
      " |  \n",
      " |  trainable\n",
      " |  \n",
      " |  trainable_variables\n",
      " |      Sequence of trainable variables owned by this module and its submodules.\n",
      " |      \n",
      " |      Note: this method uses reflection to find variables on the current instance\n",
      " |      and submodules. For performance reasons you may wish to cache the result\n",
      " |      of calling this method if you don't expect the return value to change.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A sequence of variables for the current module (sorted by attribute\n",
      " |        name) followed by variables from all submodules recursively (breadth\n",
      " |        first).\n",
      " |  \n",
      " |  updates\n",
      " |  \n",
      " |  variable_dtype\n",
      " |      Alias of `Layer.dtype`, the dtype of the weights.\n",
      " |  \n",
      " |  variables\n",
      " |      Returns the list of all layer variables/weights.\n",
      " |      \n",
      " |      Alias of `self.weights`.\n",
      " |      \n",
      " |      Note: This will not track the weights of nested `tf.Modules` that are not\n",
      " |      themselves Keras layers.\n",
      " |      \n",
      " |      Returns:\n",
      " |        A list of variables.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Class methods inherited from tensorflow.python.module.module.Module:\n",
      " |  \n",
      " |  with_name_scope(method) from builtins.type\n",
      " |      Decorator to automatically enter the module name scope.\n",
      " |      \n",
      " |      >>> class MyModule(tf.Module):\n",
      " |      ...   @tf.Module.with_name_scope\n",
      " |      ...   def __call__(self, x):\n",
      " |      ...     if not hasattr(self, 'w'):\n",
      " |      ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))\n",
      " |      ...     return tf.matmul(x, self.w)\n",
      " |      \n",
      " |      Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose\n",
      " |      names included the module name:\n",
      " |      \n",
      " |      >>> mod = MyModule()\n",
      " |      >>> mod(tf.ones([1, 2]))\n",
      " |      <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>\n",
      " |      >>> mod.w\n",
      " |      <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,\n",
      " |      numpy=..., dtype=float32)>\n",
      " |      \n",
      " |      Args:\n",
      " |        method: The method to wrap.\n",
      " |      \n",
      " |      Returns:\n",
      " |        The original method wrapped such that it enters the module's name scope.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors inherited from tensorflow.python.module.module.Module:\n",
      " |  \n",
      " |  name_scope\n",
      " |      Returns a `tf.name_scope` instance for this class.\n",
      " |  \n",
      " |  submodules\n",
      " |      Sequence of all sub-modules.\n",
      " |      \n",
      " |      Submodules are modules which are properties of this module, or found as\n",
      " |      properties of modules which are properties of this module (and so on).\n",
      " |      \n",
      " |      >>> a = tf.Module()\n",
      " |      >>> b = tf.Module()\n",
      " |      >>> c = tf.Module()\n",
      " |      >>> a.b = b\n",
      " |      >>> b.c = c\n",
      " |      >>> list(a.submodules) == [b, c]\n",
      " |      True\n",
      " |      >>> list(b.submodules) == [c]\n",
      " |      True\n",
      " |      >>> list(c.submodules) == []\n",
      " |      True\n",
      " |      \n",
      " |      Returns:\n",
      " |        A sequence of all submodules.\n",
      " |  \n",
      " |  ----------------------------------------------------------------------\n",
      " |  Data descriptors inherited from tensorflow.python.training.tracking.base.Trackable:\n",
      " |  \n",
      " |  __dict__\n",
      " |      dictionary for instance variables (if defined)\n",
      " |  \n",
      " |  __weakref__\n",
      " |      list of weak references to the object (if defined)\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\u001b[0;31mInit signature:\u001b[0m \u001b[0mtfdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRandomForestModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
       "\u001b[0;31mDocstring:\u001b[0m     \n",
       "Random Forest learning algorithm.\n",
       "\n",
       "A Random Forest (https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf)\n",
       "is a collection of deep CART decision trees trained independently and without\n",
       "pruning. Each tree is trained on a random subset of the original training \n",
       "dataset (sampled with replacement).\n",
       "\n",
       "The algorithm is unique in that it is robust to overfitting, even in extreme\n",
       "cases e.g. when there is more features than training examples.\n",
       "\n",
       "It is probably the most well-known of the Decision Forest training\n",
       "algorithms.\n",
       "\n",
       "Usage example:\n",
       "\n",
       "```python\n",
       "import tensorflow_decision_forests as tfdf\n",
       "import pandas as pd\n",
       "\n",
       "dataset = pd.read_csv(\"project/dataset.csv\")\n",
       "tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset, label=\"my_label\")\n",
       "\n",
       "model = tfdf.keras.RandomForestModel()\n",
       "model.fit(tf_dataset)\n",
       "\n",
       "print(model.summary())\n",
       "```\n",
       "\n",
       "Attributes:\n",
       "  task: Task to solve (e.g. Task.CLASSIFICATION, Task.REGRESSION,\n",
       "    Task.RANKING).\n",
       "  features: Specify the list and semantic of the input features of the model.\n",
       "    If not specified, all the available features will be used. If specified\n",
       "    and if `exclude_non_specified_features=True`, only the features in\n",
       "    `features` will be used by the model. If \"preprocessing\" is used,\n",
       "    `features` corresponds to the output of the preprocessing. In this case,\n",
       "    it is recommended for the preprocessing to return a dictionary of tensors.\n",
       "  exclude_non_specified_features: If true, only use the features specified in\n",
       "    `features`.\n",
       "  preprocessing: Functional keras model or @tf.function to apply on the input\n",
       "    feature before the model to train. This preprocessing model can consume\n",
       "    and return tensors, list of tensors or dictionary of tensors. If\n",
       "    specified, the model only \"sees\" the output of the preprocessing (and not\n",
       "    the raw input). Can be used to prepare the features or to stack multiple\n",
       "    models on top of each other. Unlike preprocessing done in the tf.dataset,\n",
       "    the operation in \"preprocessing\" are serialized with the model.\n",
       "  postprocessing: Like \"preprocessing\" but applied on the model output.\n",
       "  ranking_group: Only for `task=Task.RANKING`. Name of a tf.string feature that\n",
       "    identifies queries in a query/document ranking task. The ranking group\n",
       "    is not added automatically for the set of features if\n",
       "    `exclude_non_specified_features=false`.\n",
       "  temp_directory: Temporary directory used to store the model Assets after the\n",
       "    training, and possibly as a work directory during the training. This\n",
       "    temporary directory is necessary for the model to be exported after\n",
       "    training e.g. `model.save(path)`. If not specified, `temp_directory` is\n",
       "    set to a temporary directory using `tempfile.TemporaryDirectory`. This\n",
       "    directory is deleted when the model python object is garbage-collected.\n",
       "  verbose: If true, displays information about the training.\n",
       "  hyperparameter_template: Override the default value of the hyper-parameters.\n",
       "    If None (default) the default parameters of the library are used. If set,\n",
       "    `default_hyperparameter_template` refers to one of the following\n",
       "    preconfigured hyper-parameter sets. Those sets outperforms the default\n",
       "    hyper-parameters (either generally or in specific scenarios).\n",
       "    You can omit the version (e.g. remove \"@v5\") to use the last version of\n",
       "    the template. In this case, the hyper-parameter can change in between\n",
       "    releases (not recommended for training in production).\n",
       "    - better_default@v1: A configuration that is generally better than the\n",
       "      default parameters without being more expensive. The parameters are:\n",
       "      winner_take_all=True.\n",
       "    - benchmark_rank1@v1: Top ranking hyper-parameters on our benchmark\n",
       "      slightly modified to run in reasonable time. The parameters are:\n",
       "      winner_take_all=True, categorical_algorithm=\"RANDOM\",\n",
       "      split_axis=\"SPARSE_OBLIQUE\", sparse_oblique_normalization=\"MIN_MAX\",\n",
       "      sparse_oblique_num_projections_exponent=1.0.\n",
       "\n",
       "  advanced_arguments: Advanced control of the model that most users won't need\n",
       "    to use. See `AdvancedArguments` for details.\n",
       "  num_threads: Number of threads used to train the model. Different learning\n",
       "    algorithms use multi-threading differently and with different degree of\n",
       "    efficiency. If specified, `num_threads` field of the\n",
       "    `advanced_arguments.yggdrasil_deployment_config` has priority.\n",
       "  name: The name of the model.\n",
       "  adapt_bootstrap_size_ratio_for_maximum_training_duration: Control how the\n",
       "    maximum training duration (if set) is applied. If false, the training\n",
       "    stop when the time is used. If true, adapts the size of the sampled\n",
       "    dataset used to train each tree such that `num_trees` will train within\n",
       "    `maximum_training_duration`. Has no effect if there is no maximum\n",
       "    training duration specified. Default: False.\n",
       "  allow_na_conditions: If true, the tree training evaluates conditions of the\n",
       "    type `X is NA` i.e. `X is missing`. Default: False.\n",
       "  categorical_algorithm: How to learn splits on categorical attributes.\n",
       "    - `CART`: CART algorithm. Find categorical splits of the form \"value \\\\in\n",
       "      mask\". The solution is exact for binary classification, regression and\n",
       "      ranking. It is approximated for multi-class classification. This is a\n",
       "      good first algorithm to use. In case of overfitting (very small\n",
       "      dataset, large dictionary), the \"random\" algorithm is a good\n",
       "      alternative.\n",
       "    - `ONE_HOT`: One-hot encoding. Find the optimal categorical split of the\n",
       "      form \"attribute == param\". This method is similar (but more efficient)\n",
       "      than converting converting each possible categorical value into a\n",
       "      boolean feature. This method is available for comparison purpose and\n",
       "      generally performs worse than other alternatives.\n",
       "    - `RANDOM`: Best splits among a set of random candidate. Find the a\n",
       "      categorical split of the form \"value \\\\in mask\" using a random search.\n",
       "      This solution can be seen as an approximation of the CART algorithm.\n",
       "      This method is a strong alternative to CART. This algorithm is inspired\n",
       "      from section \"5.1 Categorical Variables\" of \"Random Forest\", 2001.\n",
       "      Default: \"CART\".\n",
       "  categorical_set_split_greedy_sampling: For categorical set splits e.g.\n",
       "    texts. Probability for a categorical value to be a candidate for the\n",
       "    positive set. The sampling is applied once per node (i.e. not at every\n",
       "    step of the greedy optimization). Default: 0.1.\n",
       "  categorical_set_split_max_num_items: For categorical set splits e.g. texts.\n",
       "    Maximum number of items (prior to the sampling). If more items are\n",
       "    available, the least frequent items are ignored. Changing this value is\n",
       "    similar to change the \"max_vocab_count\" before loading the dataset, with\n",
       "    the following exception: With `max_vocab_count`, all the remaining items\n",
       "    are grouped in a special Out-of-vocabulary item. With `max_num_items`,\n",
       "    this is not the case. Default: -1.\n",
       "  categorical_set_split_min_item_frequency: For categorical set splits e.g.\n",
       "    texts. Minimum number of occurrences of an item to be considered.\n",
       "    Default: 1.\n",
       "  compute_oob_performances: If true, compute the Out-of-bag evaluation (then\n",
       "    available in the summary and model inspector). This evaluation is a cheap\n",
       "    alternative to cross-validation evaluation. Default: True.\n",
       "  compute_oob_variable_importances: If true, compute the Out-of-bag feature\n",
       "    importance (then available in the summary and model inspector). Note that\n",
       "    the OOB feature importance can be expensive to compute. Default: False.\n",
       "  growing_strategy: How to grow the tree.\n",
       "    - `LOCAL`: Each node is split independently of the other nodes. In other\n",
       "      words, as long as a node satisfy the splits \"constraints (e.g. maximum\n",
       "      depth, minimum number of observations), the node will be split. This is\n",
       "      the \"classical\" way to grow decision trees.\n",
       "    - `BEST_FIRST_GLOBAL`: The node with the best loss reduction among all\n",
       "      the nodes of the tree is selected for splitting. This method is also\n",
       "      called \"best first\" or \"leaf-wise growth\". See \"Best-first decision\n",
       "      tree learning\", Shi and \"Additive logistic regression : A statistical\n",
       "      view of boosting\", Friedman for more details. Default: \"LOCAL\".\n",
       "  in_split_min_examples_check: Whether to check the `min_examples` constraint\n",
       "    in the split search (i.e. splits leading to one child having less than\n",
       "    `min_examples` examples are considered invalid) or before the split\n",
       "    search (i.e. a node can be derived only if it contains more than\n",
       "    `min_examples` examples). If false, there can be nodes with less than\n",
       "    `min_examples` training examples. Default: True.\n",
       "  max_depth: Maximum depth of the tree. `max_depth=1` means that all trees\n",
       "    will be roots. Negative values are ignored. Default: 16.\n",
       "  max_num_nodes: Maximum number of nodes in the tree. Set to -1 to disable\n",
       "    this limit. Only available for `growing_strategy=BEST_FIRST_GLOBAL`.\n",
       "    Default: None.\n",
       "  maximum_training_duration_seconds: Maximum training duration of the model\n",
       "    expressed in seconds. Each learning algorithm is free to use this\n",
       "    parameter at it sees fit. Enabling maximum training duration makes the\n",
       "    model training non-deterministic. Default: -1.0.\n",
       "  min_examples: Minimum number of examples in a node. Default: 5.\n",
       "  missing_value_policy: Method used to handle missing attribute values.\n",
       "    - `GLOBAL_IMPUTATION`: Missing attribute values are imputed, with the\n",
       "      mean (in case of numerical attribute) or the most-frequent-item (in\n",
       "      case of categorical attribute) computed on the entire dataset (i.e. the\n",
       "      information contained in the data spec).\n",
       "    - `LOCAL_IMPUTATION`: Missing attribute values are imputed with the mean\n",
       "      (numerical attribute) or most-frequent-item (in the case of categorical\n",
       "      attribute) evaluated on the training examples in the current node.\n",
       "    - `RANDOM_LOCAL_IMPUTATION`: Missing attribute values are imputed from\n",
       "      randomly sampled values from the training examples in the current node.\n",
       "      This method was proposed by Clinic et al. in \"Random Survival Forests\"\n",
       "      (https://projecteuclid.org/download/pdfview_1/euclid.aoas/1223908043).\n",
       "      Default: \"GLOBAL_IMPUTATION\".\n",
       "  num_candidate_attributes: Number of unique valid attributes tested for each\n",
       "    node. An attribute is valid if it has at least a valid split. If\n",
       "    `num_candidate_attributes=0`, the value is set to the classical default\n",
       "    value for Random Forest: `sqrt(number of input attributes)` in case of\n",
       "    classification and `number_of_input_attributes / 3` in case of\n",
       "    regression. If `num_candidate_attributes=-1`, all the attributes are\n",
       "    tested. Default: 0.\n",
       "  num_candidate_attributes_ratio: Ratio of attributes tested at each node. If\n",
       "    set, it is equivalent to `num_candidate_attributes =\n",
       "    number_of_input_features x num_candidate_attributes_ratio`. The possible\n",
       "    values are between ]0, and 1] as well as -1. If not set or equal to -1,\n",
       "    the `num_candidate_attributes` is used. Default: -1.0.\n",
       "  num_trees: Number of individual decision trees. Increasing the number of\n",
       "    trees can increase the quality of the model at the expense of size,\n",
       "    training speed, and inference latency. Default: 300.\n",
       "  sorting_strategy: How are sorted the numerical features in order to find\n",
       "    the splits\n",
       "    - PRESORT: The features are pre-sorted at the start of the training. This\n",
       "      solution is faster but consumes much more memory than IN_NODE.\n",
       "    - IN_NODE: The features are sorted just before being used in the node.\n",
       "      This solution is slow but consumes little amount of memory.\n",
       "    . Default: \"PRESORT\".\n",
       "  sparse_oblique_normalization: For sparse oblique splits i.e.\n",
       "    `split_axis=SPARSE_OBLIQUE`. Normalization applied on the features,\n",
       "    before applying the sparse oblique projections.\n",
       "    - `NONE`: No normalization.\n",
       "    - `STANDARD_DEVIATION`: Normalize the feature by the estimated standard\n",
       "      deviation on the entire train dataset. Also known as Z-Score\n",
       "      normalization.\n",
       "    - `MIN_MAX`: Normalize the feature by the range (i.e. max-min) estimated\n",
       "      on the entire train dataset. Default: None.\n",
       "  sparse_oblique_num_projections_exponent: For sparse oblique splits i.e.\n",
       "    `split_axis=SPARSE_OBLIQUE`. Controls of the number of random projections\n",
       "    to test at each node as `num_features^num_projections_exponent`. Default:\n",
       "    None.\n",
       "  sparse_oblique_projection_density_factor: For sparse oblique splits i.e.\n",
       "    `split_axis=SPARSE_OBLIQUE`. Controls of the number of random projections\n",
       "    to test at each node as `num_features^num_projections_exponent`. Default:\n",
       "    None.\n",
       "  split_axis: What structure of split to consider for numerical features.\n",
       "    - `AXIS_ALIGNED`: Axis aligned splits (i.e. one condition at a time).\n",
       "      This is the \"classical\" way to train a tree. Default value.\n",
       "    - `SPARSE_OBLIQUE`: Sparse oblique splits (i.e. splits one a small number\n",
       "      of features) from \"Sparse Projection Oblique Random Forests\", Tomita et\n",
       "      al., 2020. Default: \"AXIS_ALIGNED\".\n",
       "  winner_take_all: Control how classification trees vote. If true, each tree\n",
       "    votes for one class. If false, each tree vote for a distribution of\n",
       "    classes. winner_take_all_inference=false is often preferable. Default:\n",
       "    True.\n",
       "\u001b[0;31mFile:\u001b[0m           /opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/__init__.py\n",
       "\u001b[0;31mType:\u001b[0m           type\n",
       "\u001b[0;31mSubclasses:\u001b[0m     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# help() works in any Python environment.\n",
    "help(tfdf.keras.RandomForestModel)\n",
    "\n",
    "# The ? suffix only works in IPython or notebooks; it usually opens in a separate panel.\n",
    "tfdf.keras.RandomForestModel?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PuWEYvXaiwhk"
   },
   "source": [
    "## Using a subset of features\n",
    "\n",
    "The previous example did not specify the features, so all the columns (except\n",
    "the label) were used as input features. The following example shows how to\n",
    "restrict the model to a chosen subset of input features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "id": "sgn_LnRz3M7z"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmp4swwgign as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0sValidation data: ({'bill_depth_mm': [18.7 18 nan ... 13.5 13.4 14.6],\n",
       "  'bill_length_mm': [39.1 40.3 nan ... 46.5 43.3 45.8],\n",
       "  'body_mass_g': [3750 3250 nan ... 4550 4400 4200],\n",
       "  'flipper_length_mm': [181 195 nan ... 210 209 210],\n",
       "  'island': [\"Torgersen\" \"Torgersen\" \"Torgersen\" ... \"Biscoe\" \"Biscoe\" \"Biscoe\"],\n",
       "  'sex': [\"male\" \"female\" \"\" ... \"female\" \"female\" \"female\"],\n",
       "  'year': [2007 2007 2007 ... 2007 2007 2007]},\n",
       " [0 0 0 ... 1 1 1])\n",
       "\n",
       "Dataset read in 0:00:00.249375\n",
       "Training model\n",
       "Model trained in 0:00:00.059518\n",
       "Compiling model\n",
       "1/1 [==============================] - 0s 458ms/step - val_loss: 0.0000e+00 - val_accuracy: 0.9636\n",
       "1/1 [==============================] - 0s 78ms/step - loss: 0.0000e+00 - accuracy: 0.9636\n",
       "{'loss': 0.0, 'accuracy': 0.9636363387107849}\n"
      ]
     }
   ],
   "source": [
    "feature_1 = tfdf.keras.FeatureUsage(name=\"bill_length_mm\")\n",
    "feature_2 = tfdf.keras.FeatureUsage(name=\"island\")\n",
    "\n",
    "all_features = [feature_1, feature_2]\n",
    "\n",
    "# Note: This model is only trained with two features. It will not be as good as\n",
    "# the one trained on all features.\n",
    "# TODO\n",
    "\n",
    "model_2 = tfdf.keras.GradientBoostedTreesModel(\n",
    "    features=all_features, exclude_non_specified_features=True)\n",
    "\n",
    "model_2.compile(metrics=[\"accuracy\"])\n",
    "model_2.fit(x=train_ds, validation_data=test_ds)\n",
    "\n",
    "print(model_2.evaluate(test_ds, return_dict=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zvM84cgCmbUR"
   },
   "source": [
    "**Note:** As expected, the accuracy is lower than that of the model trained on all features."
   ]
  },
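  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The feature-selection rule from the start of this section can be sketched in plain Python. The helper `select_input_features` below is hypothetical and does not call TF-DF. It assumes three behaviors of the `features` and `exclude_non_specified_features` arguments. Without a feature list, every column except the label is an input. With `exclude_non_specified_features=True`, only the listed features are inputs. Otherwise, all non-label columns are still used. The label name `species` is also an assumption.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper: models which columns a TF-DF Keras model would use\n",
    "# as inputs. It is an illustration only and does not import TF-DF.\n",
    "def select_input_features(columns, label, features=None,\n",
    "                          exclude_non_specified_features=False):\n",
    "    # The label column is never an input feature.\n",
    "    candidates = [c for c in columns if c != label]\n",
    "    if features and exclude_non_specified_features:\n",
    "        # Only the explicitly listed features are kept.\n",
    "        return [c for c in candidates if c in features]\n",
    "    # No feature list, or non-specified features not excluded:\n",
    "    # every non-label column is used.\n",
    "    return candidates\n",
    "\n",
    "\n",
    "# Column names from the training logs above, plus an assumed label name.\n",
    "columns = ['island', 'bill_length_mm', 'bill_depth_mm', 'flipper_length_mm',\n",
    "           'body_mass_g', 'sex', 'year', 'species']\n",
    "\n",
    "print(select_input_features(columns, 'species'))\n",
    "# The call below prints ['island', 'bill_length_mm'].\n",
    "print(select_input_features(columns, 'species',\n",
    "                            features=['bill_length_mm', 'island'],\n",
    "                            exclude_non_specified_features=True))\n",
    "```"
   ]
  },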
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MFmqpivc7x7p"
   },
   "source": [
    "**TF-DF** attaches a **semantics** to each feature. This semantics controls how\n",
    "the feature is used by the model. The following semantics are currently supported:\n",
    "\n",
    "-   **Numerical**: Generally for quantities or counts with full ordering. For\n",
    "    example, the age of a person, or the number of items in a bag. Can be a\n",
    "    float or an integer. Missing values are represented with `NaN` or with\n",
    "    an empty sparse tensor.\n",
    "-   **Categorical**: Generally for a type/class in a finite set of possible values\n",
    "    without ordering. For example, the color RED in the set {RED, BLUE, GREEN}.\n",
    "    Can be a string or an integer. Missing values are represented as \"\" (empty\n",
    "    string), the value -2, or an empty sparse tensor.\n",
    "-   **Categorical-Set**: A set of categorical values. Well suited to represent\n",
    "    tokenized text. Can be a string or an integer in a sparse tensor or a\n",
    "    ragged tensor (recommended). The order/index of each item doesn't matter.\n",
    "\n",
    "If not specified, the semantics is inferred from the representation type and shown in the training logs:\n",
    "\n",
    "- int, float (dense or sparse) → Numerical semantics\n",
    "- str (dense or sparse) → Categorical semantics\n",
    "- int, str (ragged) → Categorical-Set semantics\n",
    "\n",
    "In some cases, the inferred semantics is incorrect. For example, an enum stored as an integer is semantically categorical, but it will be detected as numerical. In this case, you should specify the `semantic` argument in the input. The `education_num` field of the Adult dataset is a classic example.\n",
    "\n",
    "This dataset doesn't contain such a feature. However, for demonstration purposes, make the model treat `year` as a categorical feature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "id": "RNRIwLYC8zrp"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmp458m1n8i as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0sValidation data: ({'bill_depth_mm': [18.7 18 nan ... 13.5 13.4 14.6],\n",
       "  'bill_length_mm': [39.1 40.3 nan ... 46.5 43.3 45.8],\n",
       "  'body_mass_g': [3750 3250 nan ... 4550 4400 4200],\n",
       "  'flipper_length_mm': [181 195 nan ... 210 209 210],\n",
       "  'island': [\"Torgersen\" \"Torgersen\" \"Torgersen\" ... \"Biscoe\" \"Biscoe\" \"Biscoe\"],\n",
       "  'sex': [\"male\" \"female\" \"\" ... \"female\" \"female\" \"female\"],\n",
       "  'year': [2007 2007 2007 ... 2007 2007 2007]},\n",
       " [0 0 0 ... 1 1 1])\n",
       "\n",
       "Dataset read in 0:00:00.178593\n",
       "Training model\n",
       "Model trained in 0:00:00.052186\n",
       "Compiling model\n",
       "1/1 [==============================] - 0s 386ms/step - val_loss: 0.0000e+00 - val_accuracy: 0.9273\n"
      ]
     }
   ],
   "source": [
    "# Define the features\n",
    "%set_cell_height 300\n",
    "\n",
    "feature_1 = tfdf.keras.FeatureUsage(name=\"year\", semantic=tfdf.keras.FeatureSemantic.CATEGORICAL)\n",
    "feature_2 = tfdf.keras.FeatureUsage(name=\"bill_length_mm\")\n",
    "feature_3 = tfdf.keras.FeatureUsage(name=\"sex\")\n",
    "all_features = [feature_1, feature_2, feature_3]\n",
    "\n",
    "model_3 = tfdf.keras.GradientBoostedTreesModel(features=all_features, exclude_non_specified_features=True)\n",
    "model_3.compile(metrics=[\"accuracy\"])\n",
    "\n",
    "with sys_pipes():\n",
    "  model_3.fit(x=train_ds, validation_data=test_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2AQaNwihcpP7"
   },
   "source": [
    "Note that `year` is in the list of CATEGORICAL features (unlike the first run)."
   ]
  },
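  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inference rules listed earlier in this section explain why `year` had to be overridden. The sketch below restates them in plain Python. `infer_semantic` is a hypothetical helper, not TF-DF code. It treats a Python list as a dense feature and uses a `ragged=True` flag as a stand-in for a ragged tensor.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def infer_semantic(values, ragged=False):\n",
    "    # Hypothetical restatement of the rules above, not TF-DF's implementation.\n",
    "    if ragged:\n",
    "        # int or str values in a ragged tensor -> Categorical-Set.\n",
    "        return 'CATEGORICAL_SET'\n",
    "    if all(isinstance(v, str) for v in values):\n",
    "        # Strings -> Categorical; an empty string marks a missing value.\n",
    "        return 'CATEGORICAL'\n",
    "    if all(isinstance(v, (int, float)) for v in values):\n",
    "        # int or float -> Numerical; NaN marks a missing value.\n",
    "        return 'NUMERICAL'\n",
    "    raise ValueError('mixed or unsupported value types')\n",
    "\n",
    "\n",
    "print(infer_semantic([2007, 2008, 2009]))                # NUMERICAL, hence the override\n",
    "print(infer_semantic([39.1, math.nan, 46.5]))            # NUMERICAL\n",
    "print(infer_semantic(['Torgersen', 'Biscoe', '']))       # CATEGORICAL\n",
    "print(infer_semantic([['a', 'b'], ['c']], ragged=True))  # CATEGORICAL_SET\n",
    "```"
   ]
  },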
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GYrw7nKN40Vm"
   },
   "source": [
    "## Hyper-parameters\n",
    "\n",
    "**Hyper-parameters** are parameters of the training algorithm that impact\n",
    "the quality of the final model. They are specified in the model class\n",
    "constructor. The list of hyper-parameters is visible with the *question mark* colab command (e.g. `?tfdf.keras.GradientBoostedTreesModel`).\n",
    "\n",
    "Alternatively, you can find them on the [TensorFlow Decision Forests GitHub](https://github.com/tensorflow/decision-forests/keras/wrappers_pre_generated.py) or in the [Yggdrasil Decision Forests documentation](https://github.com/google/yggdrasil_decision_forests/documentation/learners).\n",
    "\n",
    "The default hyper-parameters of each algorithm approximately match those of the original publication. To ensure consistency, new features and their matching hyper-parameters are always disabled by default. That's why it is a good idea to tune your hyper-parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "id": "vHgPr4Pt43hv"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmp345y68ww as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:00.114659\n",
       "Training model\n",
       "Model trained in 0:00:00.919768\n",
       "Compiling model\n",
       "1/1 [==============================] - 1s 1s/step\n"
      ]
     },
     {
      "data": {
       "text/plain": [
        "<keras.callbacks.History at 0x7eff544fb5d0>"
       ]
      },
      "execution_count": 28,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "# A classical but slightly more complex model.\n",
    "model_6 = tfdf.keras.GradientBoostedTreesModel(\n",
    "    num_trees=500, growing_strategy=\"BEST_FIRST_GLOBAL\", max_depth=8)\n",
    "model_6.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "id": "uECgPGDc2P4p"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmpjsftwr0c as temporary training directory\n",
       "Starting reading the dataset\n",
       "WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_train_function.<locals>.train_function at 0x7eff4c6d60e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_train_function.<locals>.train_function at 0x7eff4c6d60e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:00.116099\n",
       "Training model\n",
       "Model trained in 0:00:00.503838\n",
       "Compiling model\n",
       "1/1 [==============================] - 1s 669ms/step\n",
       "WARNING:tensorflow:6 out of the last 6 calls to <function CoreModel.make_predict_function.<locals>.predict_function_trained at 0x7eff4c59ccb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "WARNING:tensorflow:6 out of the last 6 calls to <function CoreModel.make_predict_function.<locals>.predict_function_trained at 0x7eff4c59ccb0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "data": {
       "text/plain": [
        "<keras.callbacks.History at 0x7eff4c6cf050>"
       ]
      },
      "execution_count": 29,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "# TODO\n",
    "# A more complex, but possibly more accurate, model.\n",
    "model_7 = tfdf.keras.GradientBoostedTreesModel(\n",
    "    num_trees=500,\n",
    "    growing_strategy=\"BEST_FIRST_GLOBAL\",\n",
    "    max_depth=8,\n",
    "    split_axis=\"SPARSE_OBLIQUE\",\n",
    "    categorical_algorithm=\"RANDOM\",\n",
    "    )\n",
    "model_7.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xk7wEmUZu3V0"
   },
   "source": [
    "As new training methods are published and implemented, combinations of hyper-parameters can emerge as good or almost-always-better than the default parameters. To avoid changing the default hyper-parameter values, these good combinations are indexed and made available as hyper-parameter templates.\n",
    "\n",
    "For example, the `benchmark_rank1` template is the best combination on our internal benchmarks. These templates are versioned to keep training configurations stable, e.g. `benchmark_rank1@v1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "id": "LtrRhMhj3hSu"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Resolve hyper-parameter template \"benchmark_rank1\" to \"benchmark_rank1@v1\" -> {'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}.\n",
       "Use /tmp/tmppp0kyhjd as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:00.115168\n",
       "Training model\n",
       "Model trained in 0:00:00.104833\n",
       "Compiling model\n",
       "1/1 [==============================] - 0s 245ms/step\n"
      ]
     },
     {
      "data": {
       "text/plain": [
        "<keras.callbacks.History at 0x7eff544fb510>"
       ]
      },
      "execution_count": 30,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "# A good template of hyper-parameters.\n",
    "model_8 = tfdf.keras.GradientBoostedTreesModel(hyperparameter_template=\"benchmark_rank1\")\n",
    "model_8.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FSDXcKXB3u6M"
   },
   "source": [
    "The available templates can be listed with `predefined_hyperparameters()`. Note that different learning algorithms have different templates, even if the names are similar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "id": "MQrWI2iv37Bo"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[HyperParameterTemplate(name='better_default', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL'}, description='A configuration that is generally better than the default parameters without being more expensive.'), HyperParameterTemplate(name='benchmark_rank1', version=1, parameters={'growing_strategy': 'BEST_FIRST_GLOBAL', 'categorical_algorithm': 'RANDOM', 'split_axis': 'SPARSE_OBLIQUE', 'sparse_oblique_normalization': 'MIN_MAX', 'sparse_oblique_num_projections_exponent': 1.0}, description='Top ranking hyper-parameters on our benchmark slightly modified to run in reasonable time.')]\n"
     ]
    }
   ],
   "source": [
    "# The hyper-parameter templates of the Gradient Boosted Tree model.\n",
    "print(tfdf.keras.GradientBoostedTreesModel.predefined_hyperparameters())"
   ]
  },
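  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The versioning rule can be illustrated with a small registry copied from the Gradient Boosted Trees output above. The log of the `benchmark_rank1` cell shows the bare name being resolved to `benchmark_rank1@v1`. The sketch below assumes two rules: a bare name resolves to the highest available version, and `name@vN` pins version N. `resolve_template` is a hypothetical helper, not part of TF-DF.\n",
    "\n",
    "```python\n",
    "# Registry copied from the predefined_hyperparameters() output above.\n",
    "TEMPLATES = {\n",
    "    ('better_default', 1): {'growing_strategy': 'BEST_FIRST_GLOBAL'},\n",
    "    ('benchmark_rank1', 1): {\n",
    "        'growing_strategy': 'BEST_FIRST_GLOBAL',\n",
    "        'categorical_algorithm': 'RANDOM',\n",
    "        'split_axis': 'SPARSE_OBLIQUE',\n",
    "        'sparse_oblique_normalization': 'MIN_MAX',\n",
    "        'sparse_oblique_num_projections_exponent': 1.0,\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def resolve_template(spec):\n",
    "    # Hypothetical resolver: 'name@vN' pins version N, and a bare 'name'\n",
    "    # picks the highest registered version (an assumption).\n",
    "    if '@v' in spec:\n",
    "        name, version = spec.split('@v')\n",
    "        version = int(version)\n",
    "    else:\n",
    "        name = spec\n",
    "        versions = [v for (n, v) in TEMPLATES if n == name]\n",
    "        if not versions:\n",
    "            raise KeyError('unknown template: ' + spec)\n",
    "        version = max(versions)\n",
    "    # Return a copy so callers cannot mutate the registry.\n",
    "    return '{}@v{}'.format(name, version), dict(TEMPLATES[(name, version)])\n",
    "\n",
    "\n",
    "resolved, params = resolve_template('benchmark_rank1')\n",
    "print(resolved)              # benchmark_rank1@v1\n",
    "print(params['split_axis'])  # SPARSE_OBLIQUE\n",
    "```\n",
    "\n",
    "Pinning the version (`benchmark_rank1@v1`) keeps a training configuration reproducible even if a newer version of the template is added later."
   ]
  },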
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gcX4tov1_lwp"
   },
   "source": [
    "## Feature Preprocessing\n",
    "\n",
    "Pre-processing features is sometimes necessary to consume signals with complex\n",
    "structures, to regularize the model or to apply transfer learning.\n",
    "Pre-processing can be done in one of three ways:\n",
    "\n",
    "1.  Preprocessing on the Pandas dataframe. This solution is easy to implement\n",
    "    and generally suitable for experimentation. However, the\n",
    "    pre-processing logic will not be exported in the model by `model.save()`.\n",
    "\n",
    "2.  [Keras Preprocessing](https://keras.io/guides/preprocessing_layers/): While\n",
    "    more complex than the previous solution, Keras Preprocessing is packaged in\n",
    "    the model.\n",
    "\n",
    "3.  [TensorFlow Feature Columns](https://www.tensorflow.org/tutorials/structured_data/feature_columns):\n",
    "    This API is part of the TF Estimator library (not Keras) and is planned for\n",
    "    deprecation. This solution is interesting when using existing preprocessing\n",
    "    code.\n",
    "\n",
    "Note: Using [TensorFlow Hub](https://www.tensorflow.org/hub)\n",
    "pre-trained embeddings are often a great way to consume text and images with\n",
    "TF-DF. For example, `hub.KerasLayer(\"https://tfhub.dev/google/nnlm-en-dim128/2\")`. See the [Intermediate tutorial](intermediate_colab.ipynb) for more details.\n",
    "\n",
    "In the next example, pre-process the `body_mass_g` feature into `body_mass_kg = body_mass_g / 1000`. The `bill_length_mm` is consumed without pre-processing. Note that such\n",
    "monotonic transformations generally have no impact on decision forest models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "id": "tGcIvTeKAApp"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmp8l_kc1rc as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:00.188831\n",
       "Training model\n",
       "Model trained in 0:00:00.011874\n",
       "Compiling model\n",
       "1/1 [==============================] - 0s 209ms/step\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "/opt/conda/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['island', 'bill_depth_mm', 'flipper_length_mm', 'sex', 'year'] which did not match any model input. They will be ignored by the model.\n",
       "  inputs = self._flatten_to_reference_inputs(inputs)\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Model: \"random_forest_model_1\"\n",
       "_________________________________________________________________\n",
       " Layer (type)                Output Shape              Param #   \n",
       "=================================================================\n",
       " model (Functional)          {'body_mass_kg': (None,   0         \n",
       "                             1),                                 \n",
       "                              'bill_length_mm': (None            \n",
       "                             , 1)}                               \n",
       "                                                                 \n",
       "=================================================================\n",
       "Total params: 1\n",
       "Trainable params: 0\n",
       "Non-trainable params: 1\n",
       "_________________________________________________________________\n",
       "Type: \"RANDOM_FOREST\"\n",
       "Task: CLASSIFICATION\n",
       "Label: \"__LABEL\"\n",
       "\n",
       "Input Features (2):\n",
       "\tbill_length_mm\n",
       "\tbody_mass_kg\n",
       "\n",
       "No weights\n",
       "\n",
       "Variable Importance: MEAN_MIN_DEPTH:\n",
       "    1.        \"__LABEL\"  1.897333 ################\n",
       "    2. \"bill_length_mm\"  0.852389 ####\n",
       "    3.   \"body_mass_kg\"  0.381333 \n",
       "\n",
       "Variable Importance: NUM_AS_ROOT:\n",
       "    1.   \"body_mass_kg\" 215.000000 ################\n",
       "    2. \"bill_length_mm\" 85.000000 \n",
       "\n",
       "Variable Importance: NUM_NODES:\n",
       "    1. \"bill_length_mm\" 399.000000 ################\n",
       "    2.   \"body_mass_kg\" 364.000000 \n",
       "\n",
       "Variable Importance: SUM_SCORE:\n",
       "    1.   \"body_mass_kg\" 6285.188379 ################\n",
       "    2. \"bill_length_mm\" 3154.026463 \n",
       "\n",
       "\n",
       "\n",
       "Winner take all: true\n",
       "Out-of-bag evaluation: accuracy:0.947368 logloss:0.386684\n",
       "Number of trees: 300\n",
       "Total number of nodes: 1826\n",
       "\n",
       "Number of nodes by tree:\n",
       "Count: 300 Average: 6.08667 StdDev: 1.8706\n",
       "Min: 3 Max: 11 Ignored: 0\n",
       "----------------------------------------------\n",
       "[  3,  4)  41  13.67%  13.67% ####\n",
       "[  4,  5)   0   0.00%  13.67%\n",
       "[  5,  6) 105  35.00%  48.67% ##########\n",
       "[  6,  7)   0   0.00%  48.67%\n",
       "[  7,  8) 107  35.67%  84.33% ##########\n",
       "[  8,  9)   0   0.00%  84.33%\n",
       "[  9, 10)  44  14.67%  99.00% ####\n",
       "[ 10, 11)   0   0.00%  99.00%\n",
       "[ 11, 11]   3   1.00% 100.00%\n",
       "\n",
       "Depth by leafs:\n",
       "Count: 1063 Average: 2.02728 StdDev: 0.813924\n",
       "Min: 1 Max: 5 Ignored: 0\n",
       "----------------------------------------------\n",
       "[ 1, 2) 269  25.31%  25.31% #####\n",
       "[ 2, 3) 563  52.96%  78.27% ##########\n",
       "[ 3, 4) 166  15.62%  93.89% ###\n",
       "[ 4, 5)  63   5.93%  99.81% #\n",
       "[ 5, 5]   2   0.19% 100.00%\n",
       "\n",
       "Number of training obs by leaf:\n",
       "Count: 1063 Average: 32.1731 StdDev: 39.7175\n",
       "Min: 5 Max: 109 Ignored: 0\n",
       "----------------------------------------------\n",
       "[   5,  10) 625  58.80%  58.80% ##########\n",
       "[  10,  15) 125  11.76%  70.56% ##\n",
       "[  15,  20)  13   1.22%  71.78%\n",
       "[  20,  26)   0   0.00%  71.78%\n",
       "[  26,  31)   0   0.00%  71.78%\n",
       "[  31,  36)   0   0.00%  71.78%\n",
       "[  36,  41)   0   0.00%  71.78%\n",
       "[  41,  47)   0   0.00%  71.78%\n",
       "[  47,  52)   0   0.00%  71.78%\n",
       "[  52,  57)   0   0.00%  71.78%\n",
       "[  57,  62)   0   0.00%  71.78%\n",
       "[  62,  68)   0   0.00%  71.78%\n",
       "[  68,  73)   0   0.00%  71.78%\n",
       "[  73,  78)   3   0.28%  72.06%\n",
       "[  78,  83)  18   1.69%  73.75%\n",
       "[  83,  89)  26   2.45%  76.20%\n",
       "[  89,  94)  68   6.40%  82.60% #\n",
       "[  94,  99)  93   8.75%  91.35% #\n",
       "[  99, 104)  50   4.70%  96.05% #\n",
       "[ 104, 109]  42   3.95% 100.00% #\n",
       "\n",
       "Attribute in nodes:\n",
       "\t399 : bill_length_mm [NUMERICAL]\n",
       "\t364 : body_mass_kg [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 0:\n",
       "\t215 : body_mass_kg [NUMERICAL]\n",
       "\t85 : bill_length_mm [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 1:\n",
       "\t320 : body_mass_kg [NUMERICAL]\n",
       "\t311 : bill_length_mm [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 2:\n",
       "\t373 : bill_length_mm [NUMERICAL]\n",
       "\t357 : body_mass_kg [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 3:\n",
       "\t398 : bill_length_mm [NUMERICAL]\n",
       "\t364 : body_mass_kg [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 5:\n",
       "\t399 : bill_length_mm [NUMERICAL]\n",
       "\t364 : body_mass_kg [NUMERICAL]\n",
       "\n",
       "Condition type in nodes:\n",
       "\t763 : HigherCondition\n",
       "Condition type in nodes with depth <= 0:\n",
       "\t300 : HigherCondition\n",
       "Condition type in nodes with depth <= 1:\n",
       "\t631 : HigherCondition\n",
       "Condition type in nodes with depth <= 2:\n",
       "\t730 : HigherCondition\n",
       "Condition type in nodes with depth <= 3:\n",
       "\t762 : HigherCondition\n",
       "Condition type in nodes with depth <= 5:\n",
       "\t763 : HigherCondition\n",
       "Node format: NOT_SET\n",
       "\n",
       "Training OOB:\n",
       "\ttrees: 1, Out-of-bag evaluation: accuracy:0.918919 logloss:2.92246\n",
       "\ttrees: 11, Out-of-bag evaluation: accuracy:0.946903 logloss:1.60712\n",
       "\ttrees: 21, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389124\n",
       "\ttrees: 31, Out-of-bag evaluation: accuracy:0.938596 logloss:0.411775\n",
       "\ttrees: 41, Out-of-bag evaluation: accuracy:0.938596 logloss:0.401537\n",
       "\ttrees: 51, Out-of-bag evaluation: accuracy:0.938596 logloss:0.401317\n",
       "\ttrees: 61, Out-of-bag evaluation: accuracy:0.95614 logloss:0.398354\n",
       "\ttrees: 72, Out-of-bag evaluation: accuracy:0.947368 logloss:0.404759\n",
       "\ttrees: 82, Out-of-bag evaluation: accuracy:0.95614 logloss:0.402354\n",
       "\ttrees: 92, Out-of-bag evaluation: accuracy:0.947368 logloss:0.400735\n",
       "\ttrees: 102, Out-of-bag evaluation: accuracy:0.95614 logloss:0.396863\n",
       "\ttrees: 112, Out-of-bag evaluation: accuracy:0.95614 logloss:0.393184\n",
       "\ttrees: 122, Out-of-bag evaluation: accuracy:0.947368 logloss:0.391585\n",
       "\ttrees: 132, Out-of-bag evaluation: accuracy:0.95614 logloss:0.391131\n",
       "\ttrees: 142, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389081\n",
       "\ttrees: 152, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389361\n",
       "\ttrees: 162, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389358\n",
       "\ttrees: 172, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389286\n",
       "\ttrees: 182, Out-of-bag evaluation: accuracy:0.95614 logloss:0.389254\n",
       "\ttrees: 192, Out-of-bag evaluation: accuracy:0.947368 logloss:0.389584\n",
       "\ttrees: 202, Out-of-bag evaluation: accuracy:0.95614 logloss:0.388841\n",
       "\ttrees: 212, Out-of-bag evaluation: accuracy:0.947368 logloss:0.389407\n",
       "\ttrees: 222, Out-of-bag evaluation: accuracy:0.947368 logloss:0.389675\n",
       "\ttrees: 232, Out-of-bag evaluation: accuracy:0.947368 logloss:0.388398\n",
       "\ttrees: 242, Out-of-bag evaluation: accuracy:0.947368 logloss:0.388136\n",
       "\ttrees: 252, Out-of-bag evaluation: accuracy:0.947368 logloss:0.387012\n",
       "\ttrees: 262, Out-of-bag evaluation: accuracy:0.947368 logloss:0.387445\n",
       "\ttrees: 272, Out-of-bag evaluation: accuracy:0.947368 logloss:0.386512\n",
       "\ttrees: 282, Out-of-bag evaluation: accuracy:0.947368 logloss:0.386564\n",
       "\ttrees: 292, Out-of-bag evaluation: accuracy:0.947368 logloss:0.386833\n",
       "\ttrees: 300, Out-of-bag evaluation: accuracy:0.947368 logloss:0.386684\n",
       "\n"
      ]
     }
   ],
   "source": [
    "%set_cell_height 300\n",
    "\n",
    "body_mass_g = tf.keras.layers.Input(shape=(1,), name=\"body_mass_g\")\n",
    "body_mass_kg = body_mass_g / 1000.0\n",
    "\n",
    "bill_length_mm = tf.keras.layers.Input(shape=(1,), name=\"bill_length_mm\")\n",
    "\n",
    "raw_inputs = {\"body_mass_g\": body_mass_g, \"bill_length_mm\": bill_length_mm}\n",
    "processed_inputs = {\"body_mass_kg\": body_mass_kg, \"bill_length_mm\": bill_length_mm}\n",
    "\n",
    "# \"preprocessor\" contains the preprocessing logic.\n",
    "preprocessor = tf.keras.Model(inputs=raw_inputs, outputs=processed_inputs)\n",
    "\n",
    "# \"model_4\" contains both the pre-processing logic and the decision forest.\n",
    "model_4 = tfdf.keras.RandomForestModel(preprocessing=preprocessor)\n",
    "model_4.fit(x=train_ds)\n",
    "\n",
    "model_4.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h1Bx3Feyjb2o"
   },
   "source": [
    "The following example re-implements the same logic using TensorFlow Feature\n",
    "Columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "id": "fnwe3sBt-yJk"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmpdluydvpv as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/1 [==============================] - ETA: 0s\n",
       "Dataset read in 0:00:00.101812\n",
       "Training model\n",
       "Model trained in 0:00:00.013404\n",
       "Compiling model\n",
       "1/1 [==============================] - 0s 124ms/step\n"
      ]
     },
     {
      "data": {
       "text/plain": [
        "<keras.callbacks.History at 0x7eff54dacc90>"
       ]
      },
      "execution_count": 33,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "def g_to_kg(x):\n",
    "  return x / 1000\n",
    "\n",
    "feature_columns = [\n",
    "    tf.feature_column.numeric_column(\"body_mass_g\", normalizer_fn=g_to_kg),\n",
    "    tf.feature_column.numeric_column(\"bill_length_mm\"),\n",
    "]\n",
    "\n",
    "preprocessing = tf.keras.layers.DenseFeatures(feature_columns)\n",
    "\n",
    "model_5 = tfdf.keras.RandomForestModel(preprocessing=preprocessing)\n",
    "model_5.compile(metrics=[\"accuracy\"])\n",
    "model_5.fit(x=train_ds)"
   ]
  },
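  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The claim at the start of this section, that monotonic transformations such as `body_mass_g / 1000` generally have no impact on decision forests, follows from how trees split numerical features. A split only compares a value against a threshold. Any strictly increasing rescaling therefore keeps the same ordering and the same candidate partitions. The toy check below assumes a single-feature split search using Gini impurity with midpoint thresholds. This is a simplification, not TF-DF's splitter, and the labels are hypothetical.\n",
    "\n",
    "```python\n",
    "def gini(labels):\n",
    "    # Gini impurity of a list of 0/1 labels.\n",
    "    if not labels:\n",
    "        return 0.0\n",
    "    p = sum(labels) / len(labels)\n",
    "    return 2 * p * (1 - p)\n",
    "\n",
    "\n",
    "def best_split(xs, ys):\n",
    "    # Exhaustive threshold search on one numerical feature. Thresholds sit\n",
    "    # halfway between consecutive distinct values. Returns the indices of\n",
    "    # the examples sent to the left branch (x <= threshold).\n",
    "    values = sorted(set(xs))\n",
    "    best_score, best_left = None, None\n",
    "    for lo, hi in zip(values, values[1:]):\n",
    "        threshold = (lo + hi) / 2\n",
    "        left = [i for i, x in enumerate(xs) if x <= threshold]\n",
    "        right = [i for i, x in enumerate(xs) if x > threshold]\n",
    "        score = (len(left) * gini([ys[i] for i in left])\n",
    "                 + len(right) * gini([ys[i] for i in right]))\n",
    "        if best_score is None or score < best_score:\n",
    "            best_score, best_left = score, frozenset(left)\n",
    "    return best_left\n",
    "\n",
    "\n",
    "body_mass_g = [3750, 3250, 4550, 4400, 4200, 3500]  # example values\n",
    "labels = [0, 0, 1, 1, 1, 0]                         # hypothetical labels\n",
    "body_mass_kg = [g / 1000.0 for g in body_mass_g]\n",
    "\n",
    "# Same partition of the examples before and after the rescaling.\n",
    "print(best_split(body_mass_g, labels) == best_split(body_mass_kg, labels))  # True\n",
    "```\n",
    "\n",
    "Only the threshold itself is rescaled (3975 g versus 3.975 kg). The examples routed to each branch, and therefore the leaf values, stay the same."
   ]
  },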
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9vif6gsAjfzv"
   },
   "source": [
    "## Training a regression model\n",
    "\n",
    "The previous example trains a classification model (TF-DF does not differentiate\n",
    "between binary classification and multi-class classification). In the next\n",
    "example, train a regression model on the\n",
    "[Abalone dataset](https://archive.ics.uci.edu/ml/datasets/abalone). The\n",
    "objective of this dataset is to predict the number of rings on an\n",
    "abalone's shell.\n",
    "\n",
    "**Note:** The CSV file was assembled by appending UCI's header and data files. No preprocessing was applied.\n",
    "\n",
    "<center>\n",
    "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/33/LivingAbalone.JPG/800px-LivingAbalone.JPG\" width=\"200\"/></center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "id": "0uKI_Uy7RyWN"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Copying gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/abalone_raw_toy.csv...\n",
       "/ [1 files][ 89.8 KiB/ 89.8 KiB]                                                \n",
       "Operation completed over 1 objects/89.8 KiB.                                     \n",
       "  Type  LongestShell  Diameter  Height  WholeWeight  ShuckedWeight  \\\n",
       "0    M         0.455     0.365   0.095       0.5140         0.2245   \n",
       "1    M         0.350     0.265   0.090       0.2255         0.0995   \n",
       "2    F         0.530     0.420   0.135       0.6770         0.2565   \n",
       "\n",
       "   VisceraWeight  ShellWeight  Rings  \n",
       "0         0.1010         0.15     15  \n",
       "1         0.0485         0.07      7  \n",
       "2         0.1415         0.21      9  \n"
      ]
     }
   ],
   "source": [
    "# Download the dataset.\n",
    "!gcloud storage cp gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/abalone_raw_toy.csv /tmp/abalone.csv\n",
    "\n",
    "dataset_df = pd.read_csv(\"/tmp/abalone.csv\")\n",
    "print(dataset_df.head(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "id": "_gjrquQySU7Q"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1413 examples in training, 586 examples for testing.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1224: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only\n",
      "  (dict(dataframe.drop(label, 1)), dataframe[label].values))\n"
     ]
    }
   ],
   "source": [
    "# Split the dataset into a training and testing dataset.\n",
    "train_ds_pd, test_ds_pd = split_dataset(dataset_df)\n",
    "print(\"{} examples in training, {} examples for testing.\".format(\n",
    "    len(train_ds_pd), len(test_ds_pd)))\n",
    "\n",
    "# Name of the label column.\n",
    "label = \"Rings\"\n",
    "\n",
    "train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)\n",
    "test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=label, task=tfdf.keras.Task.REGRESSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "id": "t8fUhQKISqYT"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmpw4c8_rps as temporary training directory\n",
       "Starting reading the dataset\n",
       "1/2 [==============>...............] - ETA: 0s\n",
       "Dataset read in 0:00:00.130998\n",
       "Training model\n",
       "Model trained in 0:00:00.572397\n",
       "Compiling model\n",
       "2/2 [==============================] - 1s 967ms/step\n"
      ]
     }
   ],
   "source": [
    "%set_cell_height 300\n",
    "\n",
    "# TODO\n",
    "# Configure the regression model.\n",
    "model_7 = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)\n",
    "\n",
    "# Optional.\n",
    "model_7.compile(metrics=[\"mse\"])\n",
    "\n",
    "# Train the model.\n",
    "with sys_pipes():\n",
    "  model_7.fit(x=train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "id": "aSriIAaMSzwA"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "WARNING:tensorflow:5 out of the last 5 calls to <function CoreModel.make_test_function.<locals>.test_function at 0x7eff54d46050> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
       "WARNING:tensorflow:5 out of the last 5 calls to <function CoreModel.make_test_function.<locals>.test_function at 0x7eff54d46050> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "2/2 [==============================] - 0s 18ms/step - loss: 0.0000e+00 - mse: 2.0869\n",
       "{'loss': 0.0, 'mse': 2.086915969848633}\n",
       "\n",
       "MSE: 2.086915969848633\n",
       "RMSE: 1.4446162015734951\n"
      ]
     }
   ],
   "source": [
    "# Evaluate the model on the test dataset.\n",
    "evaluation = model_7.evaluate(test_ds, return_dict=True)\n",
    "\n",
    "print(evaluation)\n",
    "print()\n",
    "print(f\"MSE: {evaluation['mse']}\")\n",
    "print(f\"RMSE: {math.sqrt(evaluation['mse'])}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S54mR6i9jkhp"
   },
   "source": [
    "## Training a ranking model\n",
    "\n",
    "Finally, after training classification and regression models, train a [ranking](https://en.wikipedia.org/wiki/Learning_to_rank) model.\n",
    "\n",
    "The goal of ranking is to **order** items by importance. The \"value\" of\n",
    "relevance does not matter directly. Ranking a set of *documents* with regard to\n",
    "a user *query* is an example of a ranking problem: only the order matters, and the top documents matter most.\n",
    "\n",
    "TF-DF expects ranking datasets to be presented in a \"flat\" format. A\n",
    "document+query dataset might look like this:\n",
    "\n",
    "query | document_id | feature_1 | feature_2 | relevance/label\n",
    "----- | ----------- | --------- | --------- | ---------------\n",
    "cat   | 1           | 0.1       | blue      | 4\n",
    "cat   | 2           | 0.5       | green     | 1\n",
    "cat   | 3           | 0.2       | red       | 2\n",
    "dog   | 4           | NA        | red       | 0\n",
    "dog   | 5           | 0.2       | red       | 1\n",
    "dog   | 6           | 0.6       | green     | 1\n",
    "\n",
    "The *relevance/label* is a floating point numerical value between 0 and 5\n",
    "(generally between 0 and 4), where 0 means \"completely unrelated\", 4 means \"very\n",
    "relevant\" and 5 means \"the same as the query\".\n",
    "\n",
    "Interestingly, decision forests are often good rankers, and many\n",
    "state-of-the-art ranking models are decision forests.\n",
    "\n",
    "In this example, you use a sample of the\n",
    "[LETOR3](https://www.microsoft.com/en-us/research/project/letor-learning-rank-information-retrieval/#!letor-3-0)\n",
    "dataset. More precisely, you download the `OHSUMED.zip` from [the LETOR3 repo](https://onedrive.live.com/?authkey=%21ACnoZZSZVfHPJd0&id=8FEADC23D838BDA8%21107&cid=8FEADC23D838BDA8). This dataset is stored in the\n",
    "libsvm format, so you need to convert it to CSV."
   ]
  },
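  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The flat format stores one row per (query, document) pair. As a minimal illustration, assuming only `pandas`, the following sketch rebuilds the example table above and derives the ideal ranking for each query by sorting on the relevance label within each group (ties are broken by `document_id`, an arbitrary choice for this sketch). The names `flat_df` and `ideal_order` are hypothetical and are not used elsewhere in this notebook.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# The example query/document table, one row per (query, document) pair.\n",
    "flat_df = pd.DataFrame({\n",
    "  \"query\": [\"cat\", \"cat\", \"cat\", \"dog\", \"dog\", \"dog\"],\n",
    "  \"document_id\": [1, 2, 3, 4, 5, 6],\n",
    "  \"feature_1\": [0.1, 0.5, 0.2, None, 0.2, 0.6],  # None becomes NaN (missing)\n",
    "  \"feature_2\": [\"blue\", \"green\", \"red\", \"red\", \"red\", \"green\"],\n",
    "  \"relevance\": [4, 1, 2, 0, 1, 1],\n",
    "})\n",
    "\n",
    "# Only the order inside each query matters: sorting by decreasing relevance\n",
    "# within each group gives the ranking a perfect model would produce.\n",
    "ideal_order = (\n",
    "  flat_df.sort_values([\"query\", \"relevance\", \"document_id\"],\n",
    "                      ascending=[True, False, True])\n",
    "  .groupby(\"query\")[\"document_id\"]\n",
    "  .apply(list)\n",
    "  .to_dict()\n",
    ")\n",
    "print(ideal_order)\n",
    "```\n",
    "\n",
    "Note that the relevance values themselves never appear in `ideal_order`; only the resulting order within each query does, which is what a ranking model is evaluated on."
   ]
  },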
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "id": "axD6x1ZivHCS"
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "google.colab.output.setIframeHeight(0, true, {maxHeight: 200})"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\n",
      "61825024/61824018 [==============================] - 1s 0us/step\n",
      "61833216/61824018 [==============================] - 1s 0us/step\n"
     ]
    }
   ],
   "source": [
    "%set_cell_height 200\n",
    "\n",
    "archive_path = tf.keras.utils.get_file(\"letor.zip\",\n",
    "  \"https://download.microsoft.com/download/E/7/E/E7EABEF1-4C7B-4E31-ACE5-73927950ED5E/Letor.zip\",\n",
    "  extract=True)\n",
    "\n",
    "# Path to the train and test dataset using libsvm format.\n",
    "raw_dataset_path = os.path.join(os.path.dirname(archive_path),\"OHSUMED/Data/All/OHSUMED.txt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rcManr98ZGID"
   },
   "source": [
    "The dataset is stored as a .txt file in the libsvm format, so first convert it into a CSV file."
   ]
  },
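  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before converting the whole file, it can help to see how a single line is parsed. The following minimal sketch assumes the layout `<relevance> qid:<group> <index>:<value> ... #docid = <id>`, where, as the converter below also assumes, the last three space-separated tokens are a comment. The function `parse_letor_line` and the example line are hypothetical and only illustrate the parsing logic.\n",
    "\n",
    "```python\n",
    "def parse_letor_line(line):\n",
    "  \"\"\"Splits one libsvm ranking line into (relevance, group, features).\"\"\"\n",
    "  # Drop the last 3 tokens, assumed to be the trailing comment.\n",
    "  items = line.strip().split(\" \")[:-3]\n",
    "  relevance = float(items[0])\n",
    "  group = items[1].split(\":\")[1]  # \"qid:1\" -> \"1\"\n",
    "  features = {}\n",
    "  for item in items[2:]:\n",
    "    index, value = item.split(\":\")  # \"2:2.079442\" -> (\"2\", \"2.079442\")\n",
    "    features[\"f_\" + index] = float(value)\n",
    "  return relevance, group, features\n",
    "\n",
    "# A shortened, hypothetical line; real lines contain more features.\n",
    "example = \"2 qid:1 1:3.000000 2:2.079442 3:0.272727 #docid = 12345\"\n",
    "print(parse_letor_line(example))\n",
    "```\n",
    "\n",
    "The converter below applies the same splitting to every line, but writes the values to a CSV file instead of returning them, prefixing each group with `g_` so that it is read as a categorical value."
   ]
  },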
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "id": "mkiM9HJox-e8"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Copying gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/ohsumed_toy.csv...\n",
       "/ [1 files][  1.9 MiB/  1.9 MiB]                                                \n",
       "Operation completed over 1 objects/1.9 MiB.                                      \n"
      ]
     },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>relevance</th>\n",
       "      <th>group</th>\n",
       "      <th>f_1</th>\n",
       "      <th>f_2</th>\n",
       "      <th>f_3</th>\n",
       "      <th>f_4</th>\n",
       "      <th>f_5</th>\n",
       "      <th>f_6</th>\n",
       "      <th>f_7</th>\n",
       "      <th>f_8</th>\n",
       "      <th>...</th>\n",
       "      <th>f_16</th>\n",
       "      <th>f_17</th>\n",
       "      <th>f_18</th>\n",
       "      <th>f_19</th>\n",
       "      <th>f_20</th>\n",
       "      <th>f_21</th>\n",
       "      <th>f_22</th>\n",
       "      <th>f_23</th>\n",
       "      <th>f_24</th>\n",
       "      <th>f_25</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>g_1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.079442</td>\n",
       "      <td>0.272727</td>\n",
       "      <td>0.261034</td>\n",
       "      <td>37.330565</td>\n",
       "      <td>11.431241</td>\n",
       "      <td>37.29975</td>\n",
       "      <td>1.138657</td>\n",
       "      <td>...</td>\n",
       "      <td>9.340024</td>\n",
       "      <td>24.808785</td>\n",
       "      <td>0.393091</td>\n",
       "      <td>57.416517</td>\n",
       "      <td>3.294893</td>\n",
       "      <td>25.0231</td>\n",
       "      <td>3.219799</td>\n",
       "      <td>-3.87098</td>\n",
       "      <td>-3.90273</td>\n",
       "      <td>-3.87512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>g_1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.079442</td>\n",
       "      <td>0.428571</td>\n",
       "      <td>0.400594</td>\n",
       "      <td>37.330565</td>\n",
       "      <td>11.431241</td>\n",
       "      <td>37.29975</td>\n",
       "      <td>1.814480</td>\n",
       "      <td>...</td>\n",
       "      <td>9.340024</td>\n",
       "      <td>24.808785</td>\n",
       "      <td>0.349205</td>\n",
       "      <td>43.240626</td>\n",
       "      <td>2.654724</td>\n",
       "      <td>23.4903</td>\n",
       "      <td>3.156588</td>\n",
       "      <td>-3.96838</td>\n",
       "      <td>-4.00865</td>\n",
       "      <td>-3.98670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>g_1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>37.330565</td>\n",
       "      <td>11.431241</td>\n",
       "      <td>37.29975</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>9.340024</td>\n",
       "      <td>24.808785</td>\n",
       "      <td>0.240319</td>\n",
       "      <td>25.816989</td>\n",
       "      <td>1.551342</td>\n",
       "      <td>15.8650</td>\n",
       "      <td>2.764115</td>\n",
       "      <td>-4.28166</td>\n",
       "      <td>-4.33313</td>\n",
       "      <td>-4.44161</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>3 rows × 27 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   relevance group  f_1       f_2       f_3       f_4        f_5        f_6  \\\n",
       "0          2   g_1  3.0  2.079442  0.272727  0.261034  37.330565  11.431241   \n",
       "1          0   g_1  3.0  2.079442  0.428571  0.400594  37.330565  11.431241   \n",
       "2          2   g_1  0.0  0.000000  0.000000  0.000000  37.330565  11.431241   \n",
       "\n",
       "        f_7       f_8  ...      f_16       f_17      f_18       f_19  \\\n",
       "0  37.29975  1.138657  ...  9.340024  24.808785  0.393091  57.416517   \n",
       "1  37.29975  1.814480  ...  9.340024  24.808785  0.349205  43.240626   \n",
       "2  37.29975  0.000000  ...  9.340024  24.808785  0.240319  25.816989   \n",
       "\n",
       "       f_20     f_21      f_22     f_23     f_24     f_25  \n",
       "0  3.294893  25.0231  3.219799 -3.87098 -3.90273 -3.87512  \n",
       "1  2.654724  23.4903  3.156588 -3.96838 -4.00865 -3.98670  \n",
       "2  1.551342  15.8650  2.764115 -4.28166 -4.33313 -4.44161  \n",
       "\n",
       "[3 rows x 27 columns]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def convert_libsvm_to_csv(src_path, dst_path):\n",
    "  \"\"\"Converts a libsvm ranking dataset into a flat csv file.\n",
    "\n",
    "  Note: This code is specific to the LETOR3 dataset.\n",
    "  \"\"\"\n",
    "  # The `with` statement closes both files, even if an error occurs.\n",
    "  with open(src_path, \"r\") as src_handle, open(dst_path, \"w\") as dst_handle:\n",
    "    first_line = True\n",
    "    for src_line in src_handle:\n",
    "      # Note: The last 3 items are comments.\n",
    "      items = src_line.split(\" \")[:-3]\n",
    "      relevance = items[0]\n",
    "      group = items[1].split(\":\")[1]\n",
    "      features = [item.split(\":\") for item in items[2:]]\n",
    "\n",
    "      if first_line:\n",
    "        # Csv header: relevance, group, then one column per feature index.\n",
    "        dst_handle.write(\"relevance,group,\" + \",\".join([\"f_\" + feature[0] for feature in features]) + \"\\n\")\n",
    "        first_line = False\n",
    "      dst_handle.write(relevance + \",g_\" + group + \",\" + \",\".join([feature[1] for feature in features]) + \"\\n\")\n",
    "\n",
    "# Download the toy OHSUMED CSV file.\n",
    "!gcloud storage cp gs://cloud-training/mlongcp/v3.0_MLonGC/toy_data/ohsumed_toy.csv /tmp/ohsumed.csv\n",
    "csv_dataset_path = \"/tmp/ohsumed.csv\"\n",
    "\n",
    "# Convert the raw libsvm dataset. Note: this writes to `csv_dataset_path`,\n",
    "# overwriting the file downloaded above.\n",
    "convert_libsvm_to_csv(raw_dataset_path, csv_dataset_path)\n",
    "\n",
    "# Load the dataset into a Pandas DataFrame.\n",
    "dataset_df = pd.read_csv(csv_dataset_path)\n",
    "\n",
    "# Display the first 3 examples.\n",
    "dataset_df.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "id": "wB7bWAja1G-o"
   },
   "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "11324 examples in training, 4816 examples for testing.\n"
      ]
     },
     {
      "data": {
       "text/html": [
        "<div>\n",
        "<style scoped>\n",
        "    .dataframe tbody tr th:only-of-type {\n",
        "        vertical-align: middle;\n",
        "    }\n",
        "\n",
        "    .dataframe tbody tr th {\n",
        "        vertical-align: top;\n",
        "    }\n",
        "\n",
        "    .dataframe thead th {\n",
        "        text-align: right;\n",
        "    }\n",
        "</style>\n",
        "<table border=\"1\" class=\"dataframe\">\n",
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
        "      <th>relevance</th>\n",
        "      <th>group</th>\n",
        "      <th>f_1</th>\n",
        "      <th>f_2</th>\n",
        "      <th>f_3</th>\n",
        "      <th>f_4</th>\n",
        "      <th>f_5</th>\n",
        "      <th>f_6</th>\n",
        "      <th>f_7</th>\n",
        "      <th>f_8</th>\n",
        "      <th>...</th>\n",
        "      <th>f_16</th>\n",
        "      <th>f_17</th>\n",
        "      <th>f_18</th>\n",
        "      <th>f_19</th>\n",
        "      <th>f_20</th>\n",
        "      <th>f_21</th>\n",
        "      <th>f_22</th>\n",
        "      <th>f_23</th>\n",
        "      <th>f_24</th>\n",
        "      <th>f_25</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
        "      <th>0</th>\n",
        "      <td>2</td>\n",
        "      <td>g_1</td>\n",
        "      <td>3.0</td>\n",
        "      <td>2.079442</td>\n",
        "      <td>0.272727</td>\n",
        "      <td>0.261034</td>\n",
        "      <td>37.330565</td>\n",
        "      <td>11.431241</td>\n",
        "      <td>37.29975</td>\n",
        "      <td>1.138657</td>\n",
        "      <td>...</td>\n",
        "      <td>9.340024</td>\n",
        "      <td>24.808785</td>\n",
        "      <td>0.393091</td>\n",
        "      <td>57.416517</td>\n",
        "      <td>3.294893</td>\n",
        "      <td>25.0231</td>\n",
        "      <td>3.219799</td>\n",
        "      <td>-3.87098</td>\n",
        "      <td>-3.90273</td>\n",
        "      <td>-3.87512</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
        "      <td>2</td>\n",
        "      <td>g_1</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>37.330565</td>\n",
        "      <td>11.431241</td>\n",
        "      <td>37.29975</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>...</td>\n",
        "      <td>9.340024</td>\n",
        "      <td>24.808785</td>\n",
        "      <td>0.240319</td>\n",
        "      <td>25.816989</td>\n",
        "      <td>1.551342</td>\n",
        "      <td>15.8650</td>\n",
        "      <td>2.764115</td>\n",
        "      <td>-4.28166</td>\n",
        "      <td>-4.33313</td>\n",
        "      <td>-4.44161</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
        "      <td>0</td>\n",
        "      <td>g_1</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>37.330565</td>\n",
        "      <td>11.431241</td>\n",
        "      <td>37.29975</td>\n",
        "      <td>0.000000</td>\n",
        "      <td>...</td>\n",
        "      <td>9.340024</td>\n",
        "      <td>24.808785</td>\n",
        "      <td>0.182104</td>\n",
        "      <td>23.546296</td>\n",
        "      <td>1.621393</td>\n",
        "      <td>15.2764</td>\n",
        "      <td>2.726309</td>\n",
        "      <td>-4.43073</td>\n",
        "      <td>-4.45985</td>\n",
        "      <td>-4.57053</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
        "<p>3 rows × 27 columns</p>\n",
        "</div>"
       ],
       "text/plain": [
        "   relevance group  f_1       f_2       f_3       f_4        f_5        f_6  \\\n",
        "0          2   g_1  3.0  2.079442  0.272727  0.261034  37.330565  11.431241   \n",
        "2          2   g_1  0.0  0.000000  0.000000  0.000000  37.330565  11.431241   \n",
        "4          0   g_1  0.0  0.000000  0.000000  0.000000  37.330565  11.431241   \n",
        "\n",
        "        f_7       f_8  ...      f_16       f_17      f_18       f_19  \\\n",
        "0  37.29975  1.138657  ...  9.340024  24.808785  0.393091  57.416517   \n",
        "2  37.29975  0.000000  ...  9.340024  24.808785  0.240319  25.816989   \n",
        "4  37.29975  0.000000  ...  9.340024  24.808785  0.182104  23.546296   \n",
        "\n",
        "       f_20     f_21      f_22     f_23     f_24     f_25  \n",
        "0  3.294893  25.0231  3.219799 -3.87098 -3.90273 -3.87512  \n",
        "2  1.551342  15.8650  2.764115 -4.28166 -4.33313 -4.44161  \n",
        "4  1.621393  15.2764  2.726309 -4.43073 -4.45985 -4.57053  \n",
        "\n",
        "[3 rows x 27 columns]"
       ]
      },
      "execution_count": 40,
      "metadata": {},
      "output_type": "execute_result"
     }
   ],
   "source": [
    "train_ds_pd, test_ds_pd = split_dataset(dataset_df)\n",
    "print(\"{} examples in training, {} examples for testing.\".format(\n",
    "    len(train_ds_pd), len(test_ds_pd)))\n",
    "\n",
    "# Display the first 3 examples of the training dataset.\n",
    "train_ds_pd.head(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YQKqN9zN4L00"
   },
   "source": [
    "In this dataset, the `relevance` column defines the ground-truth rank among rows that share the same `group` value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "id": "5QMbBkCEXxu_"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1224: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only\n",
      "  (dict(dataframe.drop(label, 1)), dataframe[label].values))\n"
     ]
    }
   ],
   "source": [
    "# Name of the relevance (label) column. The grouping column is \"group\".\n",
    "relevance = \"relevance\"\n",
    "\n",
    "ranking_train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=relevance, task=tfdf.keras.Task.RANKING)\n",
    "ranking_test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=relevance, task=tfdf.keras.Task.RANKING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "id": "Ba1gb75SX1rr"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 400})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Use /tmp/tmpg8di6o4f as temporary training directory\n",
       "Starting reading the dataset\n",
       " 9/12 [=====================>........] - ETA: 0s\n",
       "Dataset read in 0:00:00.708221\n",
       "Training model\n",
       "Model trained in 0:00:01.935615\n",
       "Compiling model\n",
       "12/12 [==============================] - 3s 191ms/step\n"
      ]
     }
   ],
   "source": [
    "%set_cell_height 400\n",
    "\n",
    "# TODO\n",
    "# Define the ranking model\n",
    "model_8 = tfdf.keras.GradientBoostedTreesModel(\n",
    "    task=tfdf.keras.Task.RANKING,\n",
    "    ranking_group=\"group\",\n",
    "    num_trees=50)\n",
    "\n",
    "with sys_pipes():\n",
    "  model_8.fit(x=ranking_train_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "spZCfxfR3VK0"
   },
   "source": [
    "At this point, Keras does not provide any ranking metrics. Instead, the training and validation losses (a GBDT uses a validation dataset) are shown in the training\n",
    "logs. In this case the loss is `LAMBDA_MART_NDCG5` (LambdaMART optimizing NDCG@5), and the final (i.e. at\n",
    "the end of the training) NDCG (normalized discounted cumulative gain) is the negative of the final validation loss. The exact value varies between runs; for the run shown here, the model summary below reports `Validation loss value: -0.456362`, i.e. an NDCG of about 0.456.\n",
    "\n",
    "Note that the NDCG is a value between 0 and 1. The larger the NDCG, the better\n",
    "the model. For this reason, the loss is defined as -NDCG.\n",
    "\n",
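    "To make the metric concrete, the following sketch computes NDCG@k for a single query in plain Python. It assumes one common formulation, with gain `2^relevance - 1` and a `log2(rank + 1)` discount for 1-based ranks; this is an illustration only, and TF-DF's internal NDCG implementation may differ in details such as the gain function or tie handling. The function names are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def dcg_at_k(relevances, k):\n",
    "  \"\"\"Discounted cumulative gain of the first k items, in ranked order.\"\"\"\n",
    "  # enumerate() is 0-based, so the item at 1-based rank r is discounted by log2(r + 1).\n",
    "  return sum((2 ** rel - 1) / math.log2(position + 2)\n",
    "             for position, rel in enumerate(relevances[:k]))\n",
    "\n",
    "def ndcg_at_k(relevances, k):\n",
    "  \"\"\"DCG of the given order divided by the DCG of the ideal (sorted) order.\"\"\"\n",
    "  ideal = dcg_at_k(sorted(relevances, reverse=True), k)\n",
    "  return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0\n",
    "\n",
    "# Relevance labels of one query's documents, in the order a model ranked them.\n",
    "print(ndcg_at_k([4, 2, 1], k=5))  # 1.0: already the ideal order\n",
    "print(ndcg_at_k([1, 2, 4], k=5))  # below 1.0: least relevant document ranked first\n",
    "```\n",
    "\n",
    "Dividing by the ideal DCG is what keeps the value between 0 and 1 regardless of how many relevant documents a query has, so NDCG values can be averaged across queries.\n",
    "\n",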
    "As before, the model can be analysed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "id": "L4N1R8fM4jFh"
   },
   "outputs": [
     {
      "data": {
       "application/javascript": [
        "google.colab.output.setIframeHeight(0, true, {maxHeight: 400})"
       ],
       "text/plain": [
        "<IPython.core.display.Javascript object>"
       ]
      },
      "metadata": {},
      "output_type": "display_data"
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Model: \"gradient_boosted_trees_model_6\"\n",
       "_________________________________________________________________\n",
       " Layer (type)                Output Shape              Param #   \n",
       "=================================================================\n",
       "=================================================================\n",
       "Total params: 1\n",
       "Trainable params: 0\n",
       "Non-trainable params: 1\n",
       "_________________________________________________________________\n",
       "Type: \"GRADIENT_BOOSTED_TREES\"\n",
       "Task: RANKING\n",
       "Label: \"__LABEL\"\n",
       "Rank group: \"__RANK_GROUP\"\n",
       "\n",
       "Input Features (25):\n",
       "\tf_1\n",
       "\tf_10\n",
       "\tf_11\n",
       "\tf_12\n",
       "\tf_13\n",
       "\tf_14\n",
       "\tf_15\n",
       "\tf_16\n",
       "\tf_17\n",
       "\tf_18\n",
       "\tf_19\n",
       "\tf_2\n",
       "\tf_20\n",
       "\tf_21\n",
       "\tf_22\n",
       "\tf_23\n",
       "\tf_24\n",
       "\tf_25\n",
       "\tf_3\n",
       "\tf_4\n",
       "\tf_5\n",
       "\tf_6\n",
       "\tf_7\n",
       "\tf_8\n",
       "\tf_9\n",
       "\n",
       "No weights\n",
       "\n",
       "Variable Importance: MEAN_MIN_DEPTH:\n",
       "    1. \"__RANK_GROUP\"  4.474058 ################\n",
       "    2.      \"__LABEL\"  4.474058 ################\n",
       "    3.          \"f_5\"  4.460133 ###############\n",
       "    4.          \"f_1\"  4.435530 ###############\n",
       "    5.         \"f_18\"  4.429623 ###############\n",
       "    6.         \"f_13\"  4.428001 ###############\n",
       "    7.          \"f_7\"  4.422496 ###############\n",
       "    8.          \"f_6\"  4.410153 ###############\n",
       "    9.         \"f_14\"  4.400535 ###############\n",
       "   10.         \"f_17\"  4.397343 ###############\n",
       "   11.         \"f_11\"  4.369202 ###############\n",
       "   12.          \"f_9\"  4.352847 ###############\n",
       "   13.         \"f_15\"  4.339701 ###############\n",
       "   14.          \"f_4\"  4.290652 ###############\n",
       "   15.          \"f_2\"  4.273807 ###############\n",
       "   16.         \"f_24\"  4.271796 ###############\n",
       "   17.         \"f_20\"  4.243528 ##############\n",
       "   18.         \"f_10\"  4.207370 ##############\n",
       "   19.         \"f_25\"  4.171730 ##############\n",
       "   20.         \"f_12\"  4.149826 ##############\n",
       "   21.         \"f_19\"  4.057388 #############\n",
       "   22.         \"f_16\"  3.896911 #############\n",
       "   23.         \"f_22\"  3.721394 ############\n",
       "   24.         \"f_21\"  3.647692 ###########\n",
       "   25.         \"f_23\"  3.227891 #########\n",
       "   26.          \"f_3\"  2.745987 #######\n",
       "   27.          \"f_8\"  1.208292 \n",
       "\n",
       "Variable Importance: NUM_AS_ROOT:\n",
       "    1.  \"f_8\" 13.000000 ################\n",
       "    2. \"f_23\"  4.000000 ####\n",
       "    3.  \"f_3\"  3.000000 ##\n",
       "    4. \"f_10\"  1.000000 \n",
       "    5. \"f_12\"  1.000000 \n",
       "\n",
       "Variable Importance: NUM_NODES:\n",
       "    1.  \"f_8\" 58.000000 ################\n",
       "    2.  \"f_3\" 33.000000 ########\n",
       "    3. \"f_19\" 30.000000 #######\n",
       "    4. \"f_23\" 29.000000 #######\n",
       "    5. \"f_25\" 23.000000 #####\n",
       "    6. \"f_21\" 20.000000 ####\n",
       "    7. \"f_22\" 19.000000 ####\n",
       "    8. \"f_16\" 18.000000 ####\n",
       "    9. \"f_20\" 18.000000 ####\n",
       "   10.  \"f_4\" 17.000000 ####\n",
       "   11. \"f_24\" 16.000000 ###\n",
       "   12. \"f_10\" 11.000000 ##\n",
       "   13. \"f_12\" 11.000000 ##\n",
       "   14.  \"f_9\" 11.000000 ##\n",
       "   15. \"f_14\"  9.000000 #\n",
       "   16.  \"f_2\"  8.000000 #\n",
       "   17. \"f_18\"  7.000000 #\n",
       "   18.  \"f_1\"  6.000000 \n",
       "   19. \"f_15\"  6.000000 \n",
       "   20. \"f_17\"  6.000000 \n",
       "   21.  \"f_6\"  6.000000 \n",
       "   22.  \"f_7\"  6.000000 \n",
       "   23. \"f_11\"  5.000000 \n",
       "   24. \"f_13\"  4.000000 \n",
       "   25.  \"f_5\"  3.000000 \n",
       "\n",
       "Variable Importance: SUM_SCORE:\n",
       "    1.  \"f_8\" 5461.255446 ################\n",
       "    2. \"f_23\" 4373.515720 ############\n",
       "    3.  \"f_3\" 3089.801418 ########\n",
       "    4. \"f_16\" 2792.764100 ########\n",
       "    5. \"f_19\" 1962.633235 #####\n",
       "    6. \"f_22\" 1559.521782 ####\n",
       "    7. \"f_21\" 1506.085651 ####\n",
       "    8. \"f_20\" 1500.190055 ####\n",
       "    9. \"f_25\" 1443.877104 ####\n",
       "   10.  \"f_2\" 1120.306174 ###\n",
       "   11. \"f_12\" 966.324088 ##\n",
       "   12. \"f_24\" 733.720112 #\n",
       "   13.  \"f_6\" 676.839801 #\n",
       "   14.  \"f_4\" 658.593045 #\n",
       "   15. \"f_10\" 639.683730 #\n",
       "   16.  \"f_9\" 628.532934 #\n",
       "   17. \"f_14\" 508.548441 #\n",
       "   18.  \"f_1\" 505.427016 #\n",
       "   19. \"f_11\" 479.768584 #\n",
       "   20.  \"f_7\" 429.637189 #\n",
       "   21. \"f_18\" 398.973265 \n",
       "   22. \"f_15\" 258.680338 \n",
       "   23.  \"f_5\" 249.870058 \n",
       "   24. \"f_17\" 127.890883 \n",
       "   25. \"f_13\" 83.563594 \n",
       "\n",
       "\n",
       "\n",
       "Loss: LAMBDA_MART_NDCG5\n",
       "Validation loss value: -0.456362\n",
       "Number of trees per iteration: 1\n",
       "Node format: NOT_SET\n",
       "Number of trees: 22\n",
       "Total number of nodes: 782\n",
       "\n",
       "Number of nodes by tree:\n",
       "Count: 22 Average: 35.5455 StdDev: 6.61079\n",
       "Min: 21 Max: 47 Ignored: 0\n",
       "----------------------------------------------\n",
       "[ 21, 22) 1   4.55%   4.55% ##\n",
       "[ 22, 23) 0   0.00%   4.55%\n",
       "[ 23, 25) 0   0.00%   4.55%\n",
       "[ 25, 26) 0   0.00%   4.55%\n",
       "[ 26, 27) 0   0.00%   4.55%\n",
       "[ 27, 29) 0   0.00%   4.55%\n",
       "[ 29, 30) 5  22.73%  27.27% ##########\n",
       "[ 30, 31) 0   0.00%  27.27%\n",
       "[ 31, 33) 1   4.55%  31.82% ##\n",
       "[ 33, 34) 3  13.64%  45.45% ######\n",
       "[ 34, 35) 0   0.00%  45.45%\n",
       "[ 35, 37) 2   9.09%  54.55% ####\n",
       "[ 37, 38) 2   9.09%  63.64% ####\n",
       "[ 38, 39) 0   0.00%  63.64%\n",
       "[ 39, 41) 2   9.09%  72.73% ####\n",
       "[ 41, 42) 2   9.09%  81.82% ####\n",
       "[ 42, 43) 0   0.00%  81.82%\n",
       "[ 43, 45) 1   4.55%  86.36% ##\n",
       "[ 45, 46) 1   4.55%  90.91% ##\n",
       "[ 46, 47] 2   9.09% 100.00% ####\n",
       "\n",
       "Depth by leafs:\n",
       "Count: 402 Average: 4.50498 StdDev: 0.820281\n",
       "Min: 1 Max: 5 Ignored: 0\n",
       "----------------------------------------------\n",
       "[ 1, 2)   1   0.25%   0.25%\n",
       "[ 2, 3)  13   3.23%   3.48%\n",
       "[ 3, 4)  40   9.95%  13.43% #\n",
       "[ 4, 5)  76  18.91%  32.34% ###\n",
       "[ 5, 5] 272  67.66% 100.00% ##########\n",
       "\n",
       "Number of training obs by leaf:\n",
       "Count: 402 Average: 562.751 StdDev: 2014.8\n",
       "Min: 5 Max: 9889 Ignored: 0\n",
       "----------------------------------------------\n",
       "[    5,  499) 366  91.04%  91.04% ##########\n",
       "[  499,  993)   6   1.49%  92.54%\n",
       "[  993, 1487)   1   0.25%  92.79%\n",
       "[ 1487, 1982)   1   0.25%  93.03%\n",
       "[ 1982, 2476)   0   0.00%  93.03%\n",
       "[ 2476, 2970)   5   1.24%  94.28%\n",
       "[ 2970, 3464)   0   0.00%  94.28%\n",
       "[ 3464, 3959)   2   0.50%  94.78%\n",
       "[ 3959, 4453)   0   0.00%  94.78%\n",
       "[ 4453, 4947)   0   0.00%  94.78%\n",
       "[ 4947, 5441)   0   0.00%  94.78%\n",
       "[ 5441, 5936)   0   0.00%  94.78%\n",
       "[ 5936, 6430)   0   0.00%  94.78%\n",
       "[ 6430, 6924)   0   0.00%  94.78%\n",
       "[ 6924, 7418)   6   1.49%  96.27%\n",
       "[ 7418, 7913)   0   0.00%  96.27%\n",
       "[ 7913, 8407)   0   0.00%  96.27%\n",
       "[ 8407, 8901)   0   0.00%  96.27%\n",
       "[ 8901, 9395)   6   1.49%  97.76%\n",
       "[ 9395, 9889]   9   2.24% 100.00%\n",
       "\n",
       "Attribute in nodes:\n",
       "\t58 : f_8 [NUMERICAL]\n",
       "\t33 : f_3 [NUMERICAL]\n",
       "\t30 : f_19 [NUMERICAL]\n",
       "\t29 : f_23 [NUMERICAL]\n",
       "\t23 : f_25 [NUMERICAL]\n",
       "\t20 : f_21 [NUMERICAL]\n",
       "\t19 : f_22 [NUMERICAL]\n",
       "\t18 : f_20 [NUMERICAL]\n",
       "\t18 : f_16 [NUMERICAL]\n",
       "\t17 : f_4 [NUMERICAL]\n",
       "\t16 : f_24 [NUMERICAL]\n",
       "\t11 : f_9 [NUMERICAL]\n",
       "\t11 : f_12 [NUMERICAL]\n",
       "\t11 : f_10 [NUMERICAL]\n",
       "\t9 : f_14 [NUMERICAL]\n",
       "\t8 : f_2 [NUMERICAL]\n",
       "\t7 : f_18 [NUMERICAL]\n",
       "\t6 : f_7 [NUMERICAL]\n",
       "\t6 : f_6 [NUMERICAL]\n",
       "\t6 : f_17 [NUMERICAL]\n",
       "\t6 : f_15 [NUMERICAL]\n",
       "\t6 : f_1 [NUMERICAL]\n",
       "\t5 : f_11 [NUMERICAL]\n",
       "\t4 : f_13 [NUMERICAL]\n",
       "\t3 : f_5 [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 0:\n",
       "\t13 : f_8 [NUMERICAL]\n",
       "\t4 : f_23 [NUMERICAL]\n",
       "\t3 : f_3 [NUMERICAL]\n",
       "\t1 : f_12 [NUMERICAL]\n",
       "\t1 : f_10 [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 1:\n",
       "\t23 : f_8 [NUMERICAL]\n",
       "\t8 : f_3 [NUMERICAL]\n",
       "\t8 : f_21 [NUMERICAL]\n",
       "\t6 : f_22 [NUMERICAL]\n",
       "\t5 : f_16 [NUMERICAL]\n",
       "\t4 : f_23 [NUMERICAL]\n",
       "\t3 : f_2 [NUMERICAL]\n",
       "\t1 : f_4 [NUMERICAL]\n",
       "\t1 : f_25 [NUMERICAL]\n",
       "\t1 : f_20 [NUMERICAL]\n",
       "\t1 : f_19 [NUMERICAL]\n",
       "\t1 : f_15 [NUMERICAL]\n",
       "\t1 : f_12 [NUMERICAL]\n",
       "\t1 : f_10 [NUMERICAL]\n",
       "\t1 : f_1 [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 2:\n",
       "\t29 : f_8 [NUMERICAL]\n",
       "\t25 : f_3 [NUMERICAL]\n",
       "\t12 : f_21 [NUMERICAL]\n",
       "\t12 : f_16 [NUMERICAL]\n",
       "\t10 : f_23 [NUMERICAL]\n",
       "\t9 : f_22 [NUMERICAL]\n",
       "\t8 : f_19 [NUMERICAL]\n",
       "\t5 : f_4 [NUMERICAL]\n",
       "\t5 : f_25 [NUMERICAL]\n",
       "\t5 : f_2 [NUMERICAL]\n",
       "\t4 : f_12 [NUMERICAL]\n",
       "\t3 : f_24 [NUMERICAL]\n",
       "\t3 : f_11 [NUMERICAL]\n",
       "\t2 : f_6 [NUMERICAL]\n",
       "\t1 : f_7 [NUMERICAL]\n",
       "\t1 : f_20 [NUMERICAL]\n",
       "\t1 : f_15 [NUMERICAL]\n",
       "\t1 : f_14 [NUMERICAL]\n",
       "\t1 : f_10 [NUMERICAL]\n",
       "\t1 : f_1 [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 3:\n",
       "\t47 : f_8 [NUMERICAL]\n",
       "\t32 : f_3 [NUMERICAL]\n",
       "\t24 : f_23 [NUMERICAL]\n",
       "\t18 : f_19 [NUMERICAL]\n",
       "\t15 : f_16 [NUMERICAL]\n",
       "\t13 : f_25 [NUMERICAL]\n",
       "\t13 : f_22 [NUMERICAL]\n",
       "\t13 : f_21 [NUMERICAL]\n",
       "\t12 : f_4 [NUMERICAL]\n",
       "\t8 : f_12 [NUMERICAL]\n",
       "\t7 : f_24 [NUMERICAL]\n",
       "\t7 : f_20 [NUMERICAL]\n",
       "\t7 : f_2 [NUMERICAL]\n",
       "\t5 : f_9 [NUMERICAL]\n",
       "\t4 : f_11 [NUMERICAL]\n",
       "\t3 : f_17 [NUMERICAL]\n",
       "\t3 : f_10 [NUMERICAL]\n",
       "\t3 : f_1 [NUMERICAL]\n",
       "\t2 : f_7 [NUMERICAL]\n",
       "\t2 : f_6 [NUMERICAL]\n",
       "\t2 : f_15 [NUMERICAL]\n",
       "\t2 : f_13 [NUMERICAL]\n",
       "\t1 : f_18 [NUMERICAL]\n",
       "\t1 : f_14 [NUMERICAL]\n",
       "\n",
       "Attribute in nodes with depth <= 5:\n",
       "\t58 : f_8 [NUMERICAL]\n",
       "\t33 : f_3 [NUMERICAL]\n",
       "\t30 : f_19 [NUMERICAL]\n",
       "\t29 : f_23 [NUMERICAL]\n",
       "\t23 : f_25 [NUMERICAL]\n",
       "\t20 : f_21 [NUMERICAL]\n",
       "\t19 : f_22 [NUMERICAL]\n",
       "\t18 : f_20 [NUMERICAL]\n",
       "\t18 : f_16 [NUMERICAL]\n",
       "\t17 : f_4 [NUMERICAL]\n",
       "\t16 : f_24 [NUMERICAL]\n",
       "\t11 : f_9 [NUMERICAL]\n",
       "\t11 : f_12 [NUMERICAL]\n",
       "\t11 : f_10 [NUMERICAL]\n",
       "\t9 : f_14 [NUMERICAL]\n",
       "\t8 : f_2 [NUMERICAL]\n",
       "\t7 : f_18 [NUMERICAL]\n",
       "\t6 : f_7 [NUMERICAL]\n",
       "\t6 : f_6 [NUMERICAL]\n",
       "\t6 : f_17 [NUMERICAL]\n",
       "\t6 : f_15 [NUMERICAL]\n",
       "\t6 : f_1 [NUMERICAL]\n",
       "\t5 : f_11 [NUMERICAL]\n",
       "\t4 : f_13 [NUMERICAL]\n",
       "\t3 : f_5 [NUMERICAL]\n",
       "\n",
       "Condition type in nodes:\n",
       "\t380 : HigherCondition\n",
       "Condition type in nodes with depth <= 0:\n",
       "\t22 : HigherCondition\n",
       "Condition type in nodes with depth <= 1:\n",
       "\t65 : HigherCondition\n",
       "Condition type in nodes with depth <= 2:\n",
       "\t138 : HigherCondition\n",
       "Condition type in nodes with depth <= 3:\n",
       "\t244 : HigherCondition\n",
       "Condition type in nodes with depth <= 5:\n",
       "\t380 : HigherCondition\n",
       "\n"
      ]
     }
   ],
   "source": [
    "# Print the summary of the model\n",
    "%set_cell_height 400\n",
    "\n",
    "model_8.summary()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "train_models_with_tensorFlow_decision_forests.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-6.m87",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m87"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
